Skip to content

Commit

Permalink
Test new approach.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699315679
  • Loading branch information
Google-ML-Automation committed Nov 27, 2024
1 parent df6758f commit e6a2100
Show file tree
Hide file tree
Showing 11 changed files with 384 additions and 84 deletions.
5 changes: 3 additions & 2 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
# This configuration is used for building the wheels.
build:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
Expand Down
7 changes: 7 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load("@tsl//third_party/py:python_wheel_version.bzl", "python_wheel_version_repository")
python_wheel_version_repository(
name = "jax_wheel_version",
file_with_version = "//jax:version.py",
version_key = "_version",
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
Expand Down
1 change: 1 addition & 0 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ async def main():

if "cuda" in wheel:
wheel_build_command.append("--config=cuda")
wheel_build_command.append("--config=cuda_wheel")
wheel_build_command.append(
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
)
Expand Down
5 changes: 5 additions & 0 deletions jax/tools/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str:
}[(platform.system(), cpu)]
return f"{platform_name}_{cpu_name}"


def build_tag() -> str:
return os.getenv("WHEEL_BUILD_TAG")


def get_githash(jaxlib_git_hash):
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
with open(jaxlib_git_hash, "r") as f:
Expand Down
16 changes: 13 additions & 3 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def _get_version_string() -> str:
# In this case we return it directly.
if _release_version is not None:
return _release_version
if os.environ.get("WHEEL_BUILD_TAG"):
return _version
if os.environ.get("WHEEL_VERSION_SUFFIX"):
return _version + os.environ.get("WHEEL_VERSION_SUFFIX")
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down Expand Up @@ -77,10 +81,16 @@ def _get_version_for_build() -> str:
"""
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
return _version_from_todays_date(_version)
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
if (
os.environ.get("JAX_RELEASE")
or os.environ.get("JAXLIB_RELEASE")
or os.environ.get("WHEEL_BUILD_TAG")
):
return _version
if os.environ.get("WHEEL_VERSION_SUFFIX"):
return _version + os.environ.get("WHEEL_VERSION_SUFFIX")
if os.environ.get("JAX_NIGHTLY") or os.environ.get("JAXLIB_NIGHTLY"):
return _version_from_todays_date(_version)
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down
207 changes: 176 additions & 31 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

"""Bazel macros used by the JAX build."""

load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel_version//:wheel_version.bzl", "WHEEL_VERSION")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
Expand Down Expand Up @@ -308,38 +310,163 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)

def _get_wheel_platform_name(platform_name, cpu_name):
platform = ""
cpu = ""
if platform_name == "linux":
platform = "manylinux2014"
cpu = cpu_name
elif platform_name == "macosx":
if cpu_name == "arm64":
cpu = "arm64"
platform = "macosx_11_0"
else:
cpu = "x86_64"
platform = "macosx_10_14"
elif platform_name == "win":
platform = "win"
cpu = "amd64"
return "{platform}_{cpu}".format(
platform = platform,
cpu = cpu,
)

def _get_cpu(platform_name, platform_tag):
# Following the convention in jax/tools/build_utils.py.
if platform_name == "macosx" and platform_tag == "arm64":
return "arm64"
if platform_name == "win" and platform_tag == "x86_64":
return "AMD64"
return "aarch64" if platform_tag == "arm64" else platform_tag

def _get_full_wheel_name(rule_name, platform_name, cpu_name, major_cuda_version, wheel_version):
if "pjrt" in rule_name:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
package_name = rule_name.replace("_wheel", "").replace(
"cuda",
"cuda{}".format(major_cuda_version),
)
return wheel_name_template.format(
package_name = package_name,
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = _get_wheel_platform_name(platform_name, cpu_name),
)

def _jax_wheel_impl(ctx):
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
wheel_type = ctx.attr.wheel_type[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
custom_version_suffix = ctx.attr.custom_version_suffix[BuildSettingInfo].value
build_date = ctx.attr.build_date[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
verify_manylinux = ctx.attr.verify_manylinux[BuildSettingInfo].value

executable = ctx.executable.wheel_binary
full_wheel_version = ctx.attr.wheel_version

if include_cuda_libs and not override_include_cuda_libs:
fail("JAX wheel shouldn't be built with CUDA dependencies." +
" Please provide `--config=cuda_wheel` for bazel build command." +
" If you absolutely need to add CUDA dependencies, provide" +
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")

env = {}
if wheel_type == "nightly":
if not build_date:
fail("--//jaxlib/tools:build_date is required for nightly builds!")
env["JAXLIB_NIGHTLY"] = "1"
formatted_date = build_date.replace("-", "")
env["WHEEL_VERSION_SUFFIX"] = ".dev{}".format(formatted_date)
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
elif wheel_type == "release":
env["JAXLIB_RELEASE"] = "1"
elif build_date:
formatted_date = build_date.replace("-", "")
formatted_hash = git_hash[:9]
if git_hash:
env["WHEEL_VERSION_SUFFIX"] = ".dev{date}+{hash}".format(
date = formatted_date,
hash = formatted_hash,
)
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
else:
env["WHEEL_VERSION_SUFFIX"] = ".dev{}".format(formatted_date)
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
if custom_version_suffix:
env["WHEEL_VERSION_SUFFIX"] = "{version_suffix}.{custom_version_suffix}".format(
version_suffix = env["WHEEL_VERSION_SUFFIX"],
custom_version_suffix = custom_version_suffix,
)
full_wheel_version += ".{}".format(custom_version_suffix)
else:
env["WHEEL_BUILD_TAG"] = "0"
full_wheel_version += "-0"

cpu = _get_cpu(ctx.attr.platform_name, ctx.attr.platform_tag)
wheel_name = _get_full_wheel_name(
ctx.label.name,
ctx.attr.platform_name,
cpu,
ctx.attr.platform_version,
full_wheel_version,
)
output_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
wheel_dir = output_file.path[:output_file.path.rfind("/")]

output = ctx.actions.declare_directory(ctx.label.name)
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument
args = {
"--output_path": wheel_dir, # required argument
"--cpu": cpu, # required argument
}
args["--jaxlib_git_hash"] = "\"{}\"".format(git_hash) # required argument

if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
args["--enable-cuda"] = "True"
if ctx.attr.platform_version == "":
fail("platform_version must be set to a valid cuda version for cuda wheels")
args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels
args["--platform_version"] = ctx.attr.platform_version # required for gpu wheels
if ctx.attr.enable_rocm:
args.add("--enable-rocm", "True")
args["--enable-rocm"] = "True"
if ctx.attr.platform_version == "":
fail("platform_version must be set to a valid rocm version for rocm wheels")
args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels
args["--platform_version"] = ctx.attr.platform_version # required for gpu wheels
if ctx.attr.skip_gpu_kernels:
args.add("--skip_gpu_kernels")

args.set_param_file_format("flag_per_line")
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
executable = executable,
args["--skip_gpu_kernels"] = "True"
args_as_string = ""
for arg in args:
args_as_string += "{arg} {val} ".format(arg = arg, val = args[arg])

ctx.actions.run_shell(
inputs = [],
command = executable.path + " " + args_as_string,
outputs = [output_file],
tools = [executable],
env = env,
)
return [DefaultInfo(files = depset(direct = [output]))]

auditwheel_show_log = None
if ctx.attr.platform_name == "linux":
auditwheel_show_log = ctx.actions.declare_file("auditwheel_show.log")
args = ctx.actions.args()
args.add("--wheel_path", output_file.path)
if verify_manylinux:
args.add("--compliance-tag", ctx.attr.manylinux_compliance_tag)
args.add("--auditwheel-show-log-path", auditwheel_show_log.path)
ctx.actions.run(
arguments = [args],
inputs = [output_file],
outputs = [auditwheel_show_log],
executable = ctx.executable.verify_manylinux_compliance_binary,
)

auditwheel_show_output = [auditwheel_show_log] if auditwheel_show_log else []
return [DefaultInfo(files = depset(direct = [output_file] + auditwheel_show_output))]

_jax_wheel = rule(
attrs = {
Expand All @@ -350,12 +477,27 @@ _jax_wheel = rule(
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:git_hash")),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"wheel_type": attr.label(default = Label("//jaxlib/tools:wheel_type")),
"custom_version_suffix": attr.label(default = Label("//jaxlib/tools:custom_version_suffix")),
"build_date": attr.label(default = Label("//jaxlib/tools:build_date")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
"wheel_version": attr.string(default = WHEEL_VERSION),
"verify_manylinux_compliance_binary": attr.label(
default = Label("@tsl//third_party/py:verify_manylinux_compliance"),
executable = True,
cfg = "exec",
),
"verify_manylinux": attr.label(default = Label("@tsl//third_party/py:verify_manylinux")),
"manylinux_compliance_tag": attr.string(mandatory = True),
},
implementation = _jax_wheel_impl,
executable = False,
Expand All @@ -380,21 +522,24 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
wheel_binary = wheel_binary,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
platform_name = select({
"@platforms//os:osx": "macosx",
"@platforms//os:macos": "macosx",
"@platforms//os:windows": "win",
"@platforms//os:linux": "linux",
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:aarch64": "arm64",
"@platforms//cpu:arm64": "arm64",
"@platforms//cpu:x86_64": "x86_64",
}),
manylinux_compliance_tag = select({
"@platforms//cpu:aarch64": "manylinux_2_17_aarch64",
"@platforms//cpu:arm64": "manylinux_2_17_aarch64",
"@platforms//cpu:x86_64": "manylinux_2_17_x86_64",
"//conditions:default": "",
}),
)

jax_test_file_visibility = []
Expand Down
Loading

0 comments on commit e6a2100

Please sign in to comment.