This repo contains a proof-of-concept implementation for our paper. Please cite the conference version.
@inproceedings{DBLP:conf/ndss/LuHGLLRHWC25,
author = {Wen{-}jie Lu and
Zhicong Huang and
Zhen Gu and
Jingyu Li and
Jian Liu and
Cheng Hong and
Kui Ren and
Tao Wei and
Wenguang Chen},
title = {{BumbleBee: Secure Two-party Inference Framework for Large Transformers}},
booktitle = {32nd Annual Network and Distributed System Security Symposium, {NDSS} 2025, February 23-28, 2025},
publisher = {The Internet Society},
year = {2025},
}
The code is still under heavy development and should not be used in any security-sensitive product.
Our implementation is built on top of the SPU library, specifically on this commit. In particular, we have made the following changes:
- Added a dispatch from DotGeneral to Batched MatMul (commit); a small illustration of the kind of op this targets follows this list.
- Added an intrinsic dispatch for the activation functions (commit).
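For context, a batched matrix multiplication written in JAX lowers to a single dot_general op with batching dimensions, which is the kind of op the first change routes to a batched MatMul. A minimal, self-contained illustration (shapes are arbitrary and for demonstration only):

import jax
import jax.numpy as jnp

# A matmul over 3-D operands lowers to lax.dot_general with batch dimensions;
# printing the jaxpr makes the dot_general op visible.
a = jnp.ones((8, 128, 64))   # (batch, M, K)
b = jnp.ones((8, 64, 32))    # (batch, K, N)

def batched_matmul(x, y):
    return jnp.matmul(x, y)

print(jax.make_jaxpr(batched_matmul)(a, b))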
We prefer a Linux build. The following build steps have been tested on Ubuntu 22.04 with gcc-11.
# set TARGETPLATFORM='linux/arm64' if ARM CPU is used.
# Update dependencies
apt update \
&& apt upgrade -y \
&& apt install -y gcc-11 g++-11 libasan6 \
git wget curl unzip autoconf make lld-15 \
cmake ninja-build vim-common libgl1 libglib2.0-0 \
&& apt clean \
&& update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 100 \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 \
&& update-alternatives --install /usr/bin/ld.lld ld.lld /usr/bin/ld.lld-15 100
# clang is required on arm64 platform
if [ "$TARGETPLATFORM" = "linux/arm64" ] ; then apt install -y clang-15 \
&& apt clean \
&& update-alternatives --install /usr/bin/clang clang /usr/bin/clang-15 100 \
&& update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-15 100 \
; fi
# nasm is only required on the amd64 platform
if [ "$TARGETPLATFORM" = "linux/amd64" ] ; then apt install -y nasm ; fi
# install conda
if [ "$TARGETPLATFORM" = "linux/arm64" ] ; then CONDA_ARCH=aarch64 ; else CONDA_ARCH=x86_64 ; fi \
&& wget https://repo.anaconda.com/miniconda/Miniconda3-py310_24.3.0-0-Linux-$CONDA_ARCH.sh \
&& bash Miniconda3-py310_24.3.0-0-Linux-$CONDA_ARCH.sh -b \
&& rm -f Miniconda3-py310_24.3.0-0-Linux-$CONDA_ARCH.sh \
&& /root/miniconda3/bin/conda init
# Add conda to path
export PATH="/root/miniconda3/bin:${PATH}"
# install bazel
if [ "$TARGETPLATFORM" = "linux/arm64" ] ; then BAZEL_ARCH=arm64 ; else BAZEL_ARCH=amd64 ; fi \
&& wget https://github.com/bazelbuild/bazelisk/releases/download/v1.20.0/bazelisk-linux-$BAZEL_ARCH \
&& mv bazelisk-linux-$BAZEL_ARCH /usr/bin/bazel \
&& chmod +x /usr/bin/bazel
# install python dependencies
python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements-dev.txt
For the commands used to install on other platforms, you can follow the Ubuntu Dockerfile.
Why? In SPU, we apply a Python hijack like this to replace the entire activation call with our MPC protocols. This hijack provides a simpler alternative to IR rewriting: pattern matching all the IR operations of a complex activation function such as GeLU is a tedious task. For the softmax function, we can easily apply this Python hijack. However, for the GeLU/SiLU activations, we unfortunately need to modify the HuggingFace source code, because the activation calls used by HuggingFace are not module-level Python calls but stored function pointers.
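For reference, a minimal sketch of what such a Python-level hijack looks like (an illustration only, not the actual SPU hook; here the wrapper simply forwards to the original function, whereas SPU reroutes the call to its 2PC softmax protocol):

import jax

_orig_softmax = jax.nn.softmax

def hijacked_softmax(x, axis=-1, **kwargs):
    # Placeholder: SPU would dispatch to its MPC protocol here.
    return _orig_softmax(x, axis=axis, **kwargs)

# Every subsequent jax.nn.softmax call made by the Flax model is intercepted.
jax.nn.softmax = hijacked_softmax

The GeLU/SiLU case, where we patch the HuggingFace sources directly instead, is shown below.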
# transformers/models/gpt2/modeling_flax_gpt2.py#297
def __call__(self, hidden_states, deterministic: bool = True):
    hidden_states = self.c_fc(hidden_states)
    # hidden_states = self.act(hidden_states)
    hidden_states = jax.nn.gelu(hidden_states)
    hidden_states = self.c_proj(hidden_states)
    hidden_states = self.dropout(hidden_states, deterministic=deterministic)
    return hidden_states

# transformers/models/bert/modeling_flax_bert.py#468
def __call__(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    # hidden_states = self.activation(hidden_states)
    hidden_states = jax.nn.gelu(hidden_states)
    return hidden_states

# transformers/models/vit/modeling_flax_vit.py#282
def __call__(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    # hidden_states = self.activation(hidden_states)
    hidden_states = jax.nn.gelu(hidden_states)
    return hidden_states
Also, we can skip some computation in JAX's argmax function. This reduces some of the cost of the one-hot encoding (e.g., the embedding lookup).
# jax/_src/lax/lax.py#4118
def __call__(self, op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    pick_op_val = self._value_comparator(op_val, acc_val)
    return (select(pick_op_val, op_val, acc_val),
            select(pick_op_val, op_index, acc_index))
    # Pick op_val if Lt (for argmin) or if NaN
    # pick_op_val = bitwise_or(self._value_comparator(op_val, acc_val),
    #                          ne(op_val, op_val))
    # If x and y are not NaN and x = y, then pick the first
    # pick_op_index = bitwise_or(pick_op_val,
    #                            bitwise_and(eq(op_val, acc_val),
    #                                        lt(op_index, acc_index)))
    # return (select(pick_op_val, op_val, acc_val),
    #         select(pick_op_index, op_index, acc_index))
bazel build -c opt examples/python/ml/flax_gpt2/...
bazel build -c opt examples/python/ml/flax_bert/...
bazel build -c opt examples/python/ml/flax_vit/...
It might take some time to fetch the dependencies.
INFO: Elapsed time: 403.615s, Critical Path: 277.54s
INFO: 3686 processes: 223 internal, 3463 linux-sandbox.
INFO: Build completed successfully, 3686 total actions
- bazel run -c opt examples/python/microbench:gelu
- bazel run -c opt examples/python/microbench:softmax
- bazel run -c opt examples/python/microbench:matmul
NOTE: We need to fetch models from HuggingFace. Please make sure your network connection is good :).
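If the connection is unreliable, one option is to pre-fetch and cache the checkpoints once before running the 2PC examples. The model names below are illustrative; the examples may load different checkpoints:

from transformers import FlaxBertModel, FlaxGPT2LMHeadModel

# Downloads the weights into the local HuggingFace cache so later runs can reuse them.
FlaxBertModel.from_pretrained("bert-base-uncased")
FlaxGPT2LMHeadModel.from_pretrained("gpt2")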
- Install datasets:
  pip install datasets
- Launch the SPU backend runtime:
  env SPU_BB_SET_IEQUAL_BITS=15 bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/conf/2pc.json up
- Run the flax_bert example:
  env SPU_BB_SET_IEQUAL_BITS=15 bazel run -c opt //examples/python/ml/flax_bert -- --config `pwd`/examples/python/conf/2pc.json

We prefer to set the environment variable SPU_BB_SET_IEQUAL_BITS to the number of bits needed to cover the vocabulary size. For instance, BERT-base has a vocabulary of about 30522 tokens, so SPU_BB_SET_IEQUAL_BITS=15 is enough for the one-hot computation.
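As a quick check, the value is simply the number of bits needed to index the vocabulary, i.e. ceil(log2(vocab_size)):

import math

# Bits needed to index the vocabulary for the one-hot computation.
print(math.ceil(math.log2(30522)))  # BERT-base vocabulary -> 15
print(math.ceil(math.log2(50257)))  # GPT-2 vocabulary (~50k) -> 16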
- Launch the SPU backend runtime:
  env SPU_BB_SET_IEQUAL_BITS=16 bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/conf/2pc.json up
- Run the flax_gpt2 example:
  env SPU_BB_SET_IEQUAL_BITS=16 bazel run -c opt //examples/python/ml/flax_gpt2 -- --config `pwd`/examples/python/conf/2pc.json

The vocabulary size in GPT-2 is about 50k, so we set SPU_BB_SET_IEQUAL_BITS=16 for the one-hot computation.
- Install packages:
  pip install datasets
- Launch the SPU backend runtime:
  bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/conf/2pc.json up
- Run the flax_vit example:
  bazel run -c opt //examples/python/ml/flax_vit/flax_vit -- --config `pwd`/examples/python/conf/2pc.json