diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..7cddb6ff7bbd --- /dev/null +++ b/BUILD @@ -0,0 +1,93 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "verify_manylinux_compliance_test", +) +load("//jax:py_deps.bzl", "transitive_py_deps") +load( + "//jaxlib:jax.bzl", + "AARCH64_MANYLINUX_TAG", + "PPC64LE_MANYLINUX_TAG", + "X86_64_MANYLINUX_TAG", + "jax_wheel", +) + +transitive_py_deps( + name = "transitive_py_deps", + deps = [ + "//jax", + "//jax:experimental", + "//jax:experimental_libs", + "//jax:lax_reference", + "//jax:pallas_mosaic_gpu", + "//jax:tpu_custom_call", + "//jax/_src/lib", + "//jax/_src/pallas/mosaic:libs", + "//jax/_src/pallas/mosaic_gpu", + "//jax/_src/pallas/triton:libs", + "//jax/experimental/jax2tf", + "//jax/experimental/jax2tf:jax2tf_libs", + "//jax/experimental/pallas/ops/gpu:libs", + "//jax/extend", + "//jax/extend:ifrt_programs", + "//jax/extend/mlir:libs", + "//jax/extend/mlir/dialects:libs", + "//jax/tools:libs", + ], +) + +py_binary( + name = "build_wheel", + srcs = ["build_wheel.py"], + deps = [ + "//jaxlib/tools:build_utils", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +jax_wheel( + name = "jax_wheel", + no_abi = True, + no_platform = True, + source_files = [ + ":transitive_py_deps", + "//jax:py.typed", + "//jax:numpy/__init__.pyi", + "//jax:_src/basearray.pyi", + "//jax:test_util_sources", + "//jax:internal_test_util_sources", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", + ], + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +verify_manylinux_compliance_test( + name = "jax_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jax_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, +) diff --git a/build_wheel.py b/build_wheel.py new file mode 100644 index 000000000000..b1f30da8ce02 --- /dev/null +++ b/build_wheel.py @@ -0,0 +1,100 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script that builds a JAX wheel, intended to be run via bazel run as part +# of the JAX build process. + +import argparse +import os +import pathlib +import shutil +import tempfile + +from jaxlib.tools import build_utils + +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") +parser.add_argument( + "--sources_path", + default=None, + help=( + "Path in which the wheel's sources should be prepared. Optional. If " + "omitted, a temporary directory will be used." + ), +) +parser.add_argument( + "--output_path", + default=None, + required=True, + help="Path to which the output wheel should be written. Required.", +) +parser.add_argument( + "--jaxlib_git_hash", + default="", + required=True, + help="Git hash. Empty if unknown. Optional.", +) +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) +args = parser.parse_args() + + +def copy_file( + src_file: str, + dst_dir: str, +) -> None: + """Copy a file to the destination directory. + + Args: + src_file: file to be copied + dst_dir: destination directory + """ + + dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file)) + os.makedirs(dest_dir_path, exist_ok=True) + shutil.copy(src_file, dest_dir_path) + os.chmod(os.path.join(dst_dir, src_file), 0o644) + + +def prepare_srcs(deps: list[str], srcs_dir: str) -> None: + """Rearrange source files in target the target directory. + + Args: + deps: a list of paths to files. + srcs_dir: target directory where files are copied to. + """ + + for file in deps: + if not (file.startswith("bazel-out") or file.startswith("external")): + copy_file(file, srcs_dir) + + +tmpdir = None +sources_path = args.sources_path +if sources_path is None: + tmpdir = tempfile.TemporaryDirectory(prefix="jax") + sources_path = tmpdir.name + +try: + os.makedirs(args.output_path, exist_ok=True) + prepare_srcs(args.srcs, pathlib.Path(sources_path)) + build_utils.build_wheel( + sources_path, + args.output_path, + package_name="jax", + git_hash=args.jaxlib_git_hash, + ) +finally: + if tmpdir: + tmpdir.cleanup() diff --git a/jax/BUILD b/jax/BUILD index 6eda8311aa55..8c245beed201 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -61,6 +61,9 @@ config_setting( exports_files([ "LICENSE", "version.py", + "py.typed", + "numpy/__init__.pyi", + "_src/basearray.pyi", ]) exports_files( @@ -110,6 +113,37 @@ package_group( packages = mosaic_gpu_internal_users, ) +py_library( + name = "experimental_libs", + srcs = glob([ + "experimental/**/*.py", + "experimental/pallas/ops/gpu/**/*.py", + ]), +) + +filegroup( + name = "test_util_sources", + srcs = [ + "_src/test_util.py", + "_src/test_warning_util.py", + ], +) + +filegroup( + name = "internal_test_util_sources", + srcs = [ + "_src/internal_test_util/__init__.py", + "_src/internal_test_util/deprecation_module.py", + "_src/internal_test_util/export_back_compat_test_util.py", + "_src/internal_test_util/lax_test_util.py", + "_src/internal_test_util/test_harnesses.py", + ] + glob( + [ + "_src/internal_test_util/lazy_loader_module/*.py", + ], + ), +) + # JAX-private test utilities. py_library( # This build target is required in order to use private test utilities in jax._src.test_util, @@ -118,10 +152,7 @@ py_library( # these are available in jax.test_util via the standard :jax target. name = "test_util", testonly = 1, - srcs = [ - "_src/test_util.py", - "_src/test_warning_util.py", - ], + srcs = [":test_util_sources"], visibility = [ ":internal", ] + jax_test_util_visibility, diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index d239fba98bc7..960e9aa4ba87 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -24,6 +24,23 @@ package( ], ) +py_library( + name = "libs", + srcs = ["__init__.py"], + deps = [ + ":core", + ":error_handling", + ":helpers", + ":interpret", + ":lowering", + ":pallas_call_registration", + ":pipeline", + ":primitives", + ":random", + ":verification", + ], +) + py_library( name = "core", srcs = ["core.py"], diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 84fae3913491..e3f5a72f72ad 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -27,6 +27,17 @@ package( ], ) +pytype_strict_library( + name = "libs", + srcs = ["__init__.py"], + deps = [ + ":core", + ":lowering", + ":pallas_call_registration", + ":primitives", + ], +) + pytype_strict_library( name = "core", srcs = ["core.py"], diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index 85ad90326859..77086376974e 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -27,6 +27,15 @@ package( default_visibility = ["//visibility:private"], ) +py_library( + name = "jax2tf_libs", + srcs = glob([ + "examples/**/*.py", + "tests/**/*.py", + ]), + visibility = ["//visibility:public"], +) + py_library( name = "jax2tf", srcs = ["__init__.py"], diff --git a/jax/experimental/pallas/ops/gpu/BUILD b/jax/experimental/pallas/ops/gpu/BUILD index 20ff2152c356..8cbaa5e7072d 100644 --- a/jax/experimental/pallas/ops/gpu/BUILD +++ b/jax/experimental/pallas/ops/gpu/BUILD @@ -21,6 +21,11 @@ exports_files( srcs = glob(["*.py"]), ) +py_library( + name = "libs", + srcs = glob(["*.py"]), +) + filegroup( name = "triton_ops", srcs = glob( diff --git a/jax/extend/mlir/BUILD b/jax/extend/mlir/BUILD index 8b8304282da0..a6f907950df4 100644 --- a/jax/extend/mlir/BUILD +++ b/jax/extend/mlir/BUILD @@ -23,6 +23,15 @@ package( default_visibility = ["//jax:jax_extend_users"], ) +pytype_strict_library( + name = "libs", + srcs = ["__init__.py"], + deps = [ + ":ir", + ":pass_manager", + ], +) + pytype_strict_library( name = "ir", srcs = ["ir.py"], diff --git a/jax/extend/mlir/dialects/BUILD b/jax/extend/mlir/dialects/BUILD index 7bd9e95b0175..fcf82ce5eee8 100644 --- a/jax/extend/mlir/dialects/BUILD +++ b/jax/extend/mlir/dialects/BUILD @@ -23,6 +23,24 @@ package( default_visibility = ["//jax:jax_extend_users"], ) +pytype_strict_library( + name = "libs", + srcs = ["__init__.py"], + deps = [ + ":arithmetic_dialect", + ":builtin_dialect", + ":chlo_dialect", + ":func_dialect", + ":math_dialect", + ":memref_dialect", + ":scf_dialect", + ":sdy_dialect", + ":sparse_tensor_dialect", + ":stablehlo_dialect", + ":vector_dialect", + ], +) + pytype_strict_library( name = "arithmetic_dialect", srcs = ["arith.py"], diff --git a/jax/py_deps.bzl b/jax/py_deps.bzl new file mode 100644 index 000000000000..60eef58e5cb8 --- /dev/null +++ b/jax/py_deps.bzl @@ -0,0 +1,40 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rule for collecting python files that a target depends on. + +It traverses dependencies of provided targets, collect their direct and transitive python deps and +then return a list of paths to files. +""" + +def _transitive_py_deps_impl(ctx): + outputs = depset( + [], + transitive = [dep[PyInfo].transitive_sources for dep in ctx.attr.deps], + ) + return DefaultInfo(files = outputs) + +_transitive_py_deps = rule( + attrs = { + "deps": attr.label_list( + allow_files = True, + providers = [PyInfo], + ), + }, + implementation = _transitive_py_deps_impl, +) + +def transitive_py_deps(name, deps = []): + _transitive_py_deps(name = name + "_gather", deps = deps) + native.filegroup(name = name, srcs = [":" + name + "_gather"]) diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 3e0a950292a3..b46346c2146c 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -25,6 +25,11 @@ package( default_visibility = ["//visibility:public"], ) +py_library( + name = "libs", + srcs = glob(["**/*.py"]), +) + py_library( name = "jax_to_ir", srcs = ["jax_to_ir.py"], diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index e42f1e311931..a4bca77f5b49 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -62,6 +62,12 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } +AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) + +PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) + +X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")]) + # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13": @@ -321,7 +327,7 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) -def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version): +def _get_full_wheel_name(package_name, no_abi, no_platform, platform_name, cpu_name, wheel_version): if no_abi: wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" else: @@ -332,7 +338,9 @@ def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_ve python_version = python_version, major_python_version = python_version[0], wheel_version = wheel_version, - wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]), + wheel_platform_tag = "any" if no_platform else "_".join( + PLATFORM_TAGS_DICT[platform_name, cpu_name], + ), ) def _jax_wheel_impl(ctx): @@ -360,10 +368,13 @@ def _jax_wheel_impl(ctx): env["JAX_RELEASE"] = "1" cpu = ctx.attr.cpu + no_abi = ctx.attr.no_abi + no_platform = ctx.attr.no_platform platform_name = ctx.attr.platform_name wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, - no_abi = ctx.attr.no_abi, + no_abi = no_abi, + no_platform = no_platform, platform_name = platform_name, cpu_name = cpu, wheel_version = full_wheel_version, @@ -373,7 +384,8 @@ def _jax_wheel_impl(ctx): wheel_dir = output_file.path[:output_file.path.rfind("/")] args.add("--output_path", wheel_dir) # required argument - args.add("--cpu", cpu) # required argument + if not no_platform: + args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument if ctx.attr.enable_cuda: @@ -389,11 +401,17 @@ def _jax_wheel_impl(ctx): if ctx.attr.skip_gpu_kernels: args.add("--skip_gpu_kernels") + srcs = [] + for src in ctx.attr.source_files: + for f in src.files.to_list(): + srcs.append(f) + args.add("--srcs=%s" % (f.path)) + args.set_param_file_format("flag_per_line") args.use_param_file("@%s", use_always = False) ctx.actions.run( arguments = [args], - inputs = [], + inputs = srcs, outputs = [output_file], executable = executable, env = env, @@ -411,9 +429,11 @@ _jax_wheel = rule( ), "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), + "no_platform": attr.bool(default = False), "cpu": attr.string(mandatory = True), "platform_name": attr.string(mandatory = True), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), + "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. @@ -427,7 +447,15 @@ _jax_wheel = rule( executable = False, ) -def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""): +def jax_wheel( + name, + wheel_binary, + wheel_name, + no_abi = False, + no_platform = False, + enable_cuda = False, + platform_version = "", + source_files = []): """Create jax artifact wheels. Common artifact attributes are grouped within a single macro. @@ -437,8 +465,10 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI + no_platform: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel + source_files: the source files to include in the wheel Returns: A directory containing the wheel @@ -448,6 +478,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals wheel_binary = wheel_binary, wheel_name = wheel_name, no_abi = no_abi, + no_platform = no_platform, enable_cuda = enable_cuda, platform_version = platform_version, # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` @@ -465,6 +496,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals "//jaxlib/tools:arm64": "aarch64", "@platforms//cpu:x86_64": "x86_64", }), + source_files = source_files, ) jax_test_file_visibility = [] diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 3188463817c8..94dcfe984319 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -24,7 +24,10 @@ load( ) load( "//jaxlib:jax.bzl", + "AARCH64_MANYLINUX_TAG", "PLATFORM_TAGS_DICT", + "PPC64LE_MANYLINUX_TAG", + "X86_64_MANYLINUX_TAG", "if_windows", "jax_py_test", "jax_wheel", @@ -231,12 +234,6 @@ jax_wheel( wheel_name = "jax_cuda12_pjrt", ) -AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) - -PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) - -X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")]) - verify_manylinux_compliance_test( name = "jaxlib_manylinux_compliance_test", aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,