Skip to content

Commit

Permalink
Add unstack_and_unshard for SparseCore
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714073293
  • Loading branch information
ChromeHearts authored and The sparsecore Authors committed Jan 10, 2025
1 parent 48edd8d commit 8fd28a4
Show file tree
Hide file tree
Showing 21 changed files with 1,367 additions and 23 deletions.
5 changes: 5 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extention that rejects unknown arguments.
build:clang --copt=-Qunused-arguments

##############################################################################
# Test configurations.
##############################################################################
test:cpu --test_env=JAX_PLATFORMS=cpu --test_tag_filters=cpu

#############################################################################
# Some configs to make getting some forms of debug builds. In general, the
# codebase is only regularly built with optimizations. Use 'debug_symbols' to
Expand Down
73 changes: 73 additions & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
name: Build and test

on:
# Only run workflow on pushes to main (includes PR merge), and on
# opened pull-requests.
push:
branches:
- main
pull_request:

jobs:
build_and_test:
runs-on: ubuntu-24.04
strategy:
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'

- name: Display Python version
run: python -c "import sys; print(sys.version)"

- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
. build/install_bazelisk.sh
# Load different caches depending on if this is a pull-request or merge.
# If merge (or push commit), use a read-write cache based on the python
# version, branch, and commit-sha.
# If pull-request, use a read-only cache based on the target python
# version, branch, and PR base sha.
- if: github.event_name != 'pull_request'
name: Mount bazel cache (main)
uses: actions/cache@v4
with:
path: "/home/runner/.cache/bazel"
key: bazel-${{ matrix.python-version }}-${{ github.ref_name }}-${{ github.sha }}
restore-keys: |
bazel-${{ matrix.python-version }}-${{ github.ref_name }}
bazel-${{ matrix.python-version }}-
bazel-
- if: github.event_name == 'pull_request'
name: Mount bazel cache (pull-request)
uses: actions/cache/restore@v4
with:
path: "/home/runner/.cache/bazel"
key: bazel-${{ matrix.python-version }}-${{ github.base_ref }}-${{ github.event.pull_request.base.sha }}
restore-keys: |
bazel-${{ matrix.python-version }}-${{ github.base_ref }}
bazel-${{ matrix.python-version }}-
bazel-
- name: Build all targets
run: |
export HERMETIC_PYTHON_VERSION=${{ matrix.python-version }}
bazel build //...
- name: Build pip wheel
run: |
bazel run //build:build_pip_package -- $PWD
- name: Run CPU tests
run: |
bazel test --config=cpu --test_output=errors --keep_going //...
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ poetry.lock

# PyCharm
.idea

# Bazel
/bazel-*

# Built wheels.
/*.whl
2 changes: 1 addition & 1 deletion build/requirements.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Library.
absl-py
flax
flax @ https://github.com/google/flax/archive/e2134af.zip
numpy
dm-tree
# Pre-release of JAX required for SparseCore TPUs.
Expand Down
10 changes: 7 additions & 3 deletions build/requirements_lock_3_10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
# clu
# optax
# orbax-checkpoint
flax==0.10.1 \
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
flax @ https://github.com/google/flax/archive/e2134af.zip \
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
# via
# -r build/requirements.in
# clu
Expand Down Expand Up @@ -463,6 +462,7 @@ numpy==2.1.3 \
# orbax-checkpoint
# scipy
# tensorstore
# treescope
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
Expand Down Expand Up @@ -641,6 +641,10 @@ toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
treescope==0.1.7 \
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
# via flax
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
Expand Down
10 changes: 7 additions & 3 deletions build/requirements_lock_3_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
# clu
# optax
# orbax-checkpoint
flax==0.10.1 \
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
flax @ https://github.com/google/flax/archive/e2134af.zip \
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
# via
# -r build/requirements.in
# clu
Expand Down Expand Up @@ -464,6 +463,7 @@ numpy==2.1.3 \
# orbax-checkpoint
# scipy
# tensorstore
# treescope
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
Expand Down Expand Up @@ -642,6 +642,10 @@ toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
treescope==0.1.7 \
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
# via flax
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
Expand Down
10 changes: 7 additions & 3 deletions build/requirements_lock_3_12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
# clu
# optax
# orbax-checkpoint
flax==0.10.1 \
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
flax @ https://github.com/google/flax/archive/e2134af.zip \
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
# via
# -r build/requirements.in
# clu
Expand Down Expand Up @@ -464,6 +463,7 @@ numpy==2.1.3 \
# orbax-checkpoint
# scipy
# tensorstore
# treescope
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
Expand Down Expand Up @@ -642,6 +642,10 @@ toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
treescope==0.1.7 \
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
# via flax
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
Expand Down
10 changes: 7 additions & 3 deletions build/requirements_lock_3_13.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
# clu
# optax
# orbax-checkpoint
flax==0.10.1 \
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
flax @ https://github.com/google/flax/archive/e2134af.zip \
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
# via
# -r build/requirements.in
# clu
Expand Down Expand Up @@ -464,6 +463,7 @@ numpy==2.1.3 \
# orbax-checkpoint
# scipy
# tensorstore
# treescope
opt-einsum==3.4.0 \
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
Expand Down Expand Up @@ -642,6 +642,10 @@ toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
# via chex
treescope==0.1.7 \
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
# via flax
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
Expand Down
56 changes: 56 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/auto_pipelining/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 The JAX SC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library")

package(
default_applicable_licenses = ["//:license"],
default_visibility = [
"//jax_tpu_embedding/sparsecore:__subpackages__",
],
)

pytype_strict_library(
name = "utils",
srcs = ["utils.py"],
deps = [pypi_requirement("jax")],
)

pytype_strict_library(
name = "decompose",
srcs = ["decompose.py"],
deps = [
":preprocess",
":utils",
pypi_requirement("jax"),
],
)

pytype_strict_library(
name = "preprocess",
srcs = ["preprocess.py"],
deps = [
":utils",
pypi_requirement("jax"),
],
)

pytype_strict_library(
name = "auto_pipelining",
srcs = ["auto_pipelining.py"],
deps = [
":decompose",
pypi_requirement("jax"),
],
)
Loading

0 comments on commit 8fd28a4

Please sign in to comment.