-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathWORKSPACE
99 lines (81 loc) · 3.34 KB
/
WORKSPACE
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright 2024 The JAX SC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
workspace(name = "jax_tpu_embedding")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
###############################################################################
## SparseCore Dependencies
###############################################################################

# Highway (github.com/google/highway), fetched as a source tarball for a pinned
# version. To upgrade, change HIGHWAY_VERSION and HIGHWAY_SHA256 together; the
# archive URL and the stripped directory prefix are both derived from the
# version string.
HIGHWAY_VERSION = "1.2.0"

HIGHWAY_ARCHIVE = "https://github.com/google/highway/archive/{version}.tar.gz".format(version = HIGHWAY_VERSION)

HIGHWAY_SHA256 = "7e0be78b8318e8bdbf6fa545d2ecb4c90f947df03f7aadc42c1967f019e63343"

http_archive(
    name = "highway",
    urls = [HIGHWAY_ARCHIVE],
    strip_prefix = "highway-{version}".format(version = HIGHWAY_VERSION),
    sha256 = HIGHWAY_SHA256,
)
# rules_license is declared _before_ the XLA setup below, since highway requires
# a newer version. maybe() skips the declaration when a repository with this
# name already exists.
maybe(
    http_archive,
    name = "rules_license",
    sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360",
    urls = [
        "https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz",
    ],
)
###############################################################################
## XLA Initialization
###############################################################################
# This is adapted from JAX's WORKSPACE file.
# The XLA commit is determined by third_party/xla/workspace.bzl, the file
# loaded on the next line (its `repo` symbol is aliased to `xla_repo`).
load("//third_party/xla:workspace.bzl", xla_repo = "repo")
# NOTE(review): presumably declares the `@xla` repository that the later
# `@xla//...` loads reference -- confirm in third_party/xla/workspace.bzl.
xla_repo()
# Initialize hermetic Python
# Each load() below is immediately followed by a call to the symbol it loaded.
# NOTE(review): the pairs presumably must stay in this order, since later loads
# may reference repositories set up by the earlier calls -- confirm before
# reordering.
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
python_init_rules()
load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
# Maps Python versions 3.10 through 3.13 to their requirements lock files
# under //build.
python_init_repositories(
requirements = {
"3.10": "//build:requirements_lock_3_10.txt",
"3.11": "//build:requirements_lock_3_11.txt",
"3.12": "//build:requirements_lock_3_12.txt",
"3.13": "//build:requirements_lock_3_13.txt",
},
# NOTE(review): "system" presumably means "use the Python version found on
# the host" -- confirm against XLA's python_init_repositories.
default_python_version = "system",
)
load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
python_init_toolchains()
# Unlike the steps above, python_init_pip is loaded from this workspace
# (//third_party/bazel/python) rather than from @xla.
load("//third_party/bazel/python:python_init_pip.bzl", "python_init_pip")
python_init_pip()
# NOTE(review): the @pypi repository is presumably created by python_init_pip()
# above -- confirm.
load("@pypi//:requirements.bzl", "install_deps")
install_deps()
# Load all XLA dependencies.
# The XLA workspace stages are loaded and invoked in descending order, from
# xla_workspace4 down to xla_workspace0.
# NOTE(review): each stage presumably defines repositories that the next
# stage's load() needs, so the sequence should not be reordered -- confirm
# against XLA's workspace*.bzl files.
load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()
load("@xla//:workspace3.bzl", "xla_workspace3")
xla_workspace3()
load("@xla//:workspace2.bzl", "xla_workspace2")
xla_workspace2()
load("@xla//:workspace1.bzl", "xla_workspace1")
xla_workspace1()
load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()
# TSL, one of our dependencies, needs the CUDA configuration repository to
# exist, so it is set up here even though this project does not use CUDA.
load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure")

cuda_configure(name = "local_config_cuda")