Skip to content

Commit

Permalink
Added CI tests for RabbitMQ components in AMS (#89)
Browse files Browse the repository at this point in the history
* Currently it skips RMQ tests on LC systems.
Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier authored Oct 21, 2024
1 parent f588e87 commit aa07df4
Show file tree
Hide file tree
Showing 13 changed files with 586 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .github/containers/x86_64-broadwell-cuda11.6.1/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ FROM nvidia/cuda:11.6.1-devel-ubi8 AS base
MAINTAINER Giorgis Georgakoudis <[email protected]>
RUN \
yum install -y dnf &&\
dnf install -y git xz autoconf automake unzip patch gcc-gfortran bzip2 file &&\
dnf install -y git xz autoconf automake unzip patch gcc-gfortran bzip2 file libevent-devel openssl-devel &&\
dnf upgrade -y &&\
dnf clean all
COPY repo repo
Expand Down
2 changes: 1 addition & 1 deletion .github/containers/x86_64-broadwell-gcc11.2.1/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ MAINTAINER Giorgis Georgakoudis <[email protected]>
RUN \
yum install -y dnf &&\
dnf group install -y "Development Tools" &&\
dnf install -y git gcc-toolset-11 environment-modules &&\
dnf install -y git gcc-toolset-11 environment-modules libevent-devel openssl-devel &&\
dnf upgrade -y
COPY repo repo
RUN \
Expand Down
99 changes: 99 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
branches: [ "develop" ]
workflow_dispatch:


jobs:
build-run-tests:
# The type of runner that the job will run on
Expand Down Expand Up @@ -389,3 +390,101 @@ jobs:
-DWITH_ADIAK=Off \
$GITHUB_WORKSPACE
make
build-rmq-tests:
  # Runs the RabbitMQ-enabled AMS build + tests against a live RMQ service.
  runs-on: ubuntu-latest
  services:
    # Side-car RabbitMQ broker, reachable from the job container as host "rabbitmq".
    rabbitmq:
      image: rabbitmq:3.11
      env:
        RABBITMQ_DEFAULT_USER: ams
        RABBITMQ_DEFAULT_PASS: ams
      ports:
        - 5672

  container:
    image: ghcr.io/llnl/ams-ci-almalinux8:latest
    env:
      RABBITMQ_USER: ams
      RABBITMQ_PASS: ams
      RABBITMQ_HOST: rabbitmq
      RABBITMQ_PORT: 5672

  steps:
    - uses: actions/checkout@v4
    - name: Build Torch=On FAISS=On RMQ=On AMS
      shell: bash -l {0}
      run: |
        module load gcc/11.2.1
        export SPACK_ROOT=/spack/
        source /spack/share/spack/setup-env.sh
        spack env activate -p /ams-spack-env
        rm -rf build/
        mkdir build
        cd build
        export AMS_MFEM_PATH=$(spack location -i mfem)
        export AMS_TORCH_PATH=$(spack location -i py-torch)/lib/python3.10/site-packages/torch/share/cmake/Torch
        export AMS_FAISS_PATH=$(spack location -i faiss)
        export AMS_UMPIRE_PATH=$(spack location -i umpire)
        export AMS_HDF5_PATH=$(spack location -i hdf5)
        export AMS_AMQPCPP_PATH=$(spack location -i amqp-cpp)/cmake
        cmake \
          -DBUILD_SHARED_LIBS=On \
          -DCMAKE_PREFIX_PATH=$INSTALL_DIR \
          -DWITH_CALIPER=On \
          -DWITH_HDF5=On \
          -DWITH_EXAMPLES=On \
          -DAMS_HDF5_DIR=$AMS_HDF5_PATH \
          -DCMAKE_INSTALL_PREFIX=./install \
          -DCMAKE_BUILD_TYPE=Release \
          -DWITH_CUDA=Off \
          -DUMPIRE_DIR=$AMS_UMPIRE_PATH \
          -DMFEM_DIR=$AMS_MFEM_PATH \
          -DWITH_FAISS=On \
          -DWITH_MPI=On \
          -DWITH_TORCH=On \
          -DWITH_TESTS=On \
          -DTorch_DIR=$AMS_TORCH_PATH \
          -DFAISS_DIR=$AMS_FAISS_PATH \
          -DWITH_AMS_DEBUG=On \
          -DWITH_WORKFLOW=On \
          -DWITH_ADIAK=Off \
          -DWITH_RMQ=On \
          -Damqpcpp_DIR=$AMS_AMQPCPP_PATH \
          $GITHUB_WORKSPACE
        make
    - name: Run tests Torch=On FAISS=On RMQ=On AMSlib RabbitMQ egress
      run: |
        cd build
        export SPACK_ROOT=/spack/
        source /spack/share/spack/setup-env.sh
        spack env activate -p /ams-spack-env
        # We overwrite the rmq.json created by CMake
        echo """{
          \"db\": {
            \"dbType\": \"rmq\",
            \"rmq_config\": {
              \"rabbitmq-name\": \"rabbit\",
              \"rabbitmq-user\": \"${RABBITMQ_USER}\",
              \"rabbitmq-password\": \"${RABBITMQ_PASS}\",
              \"service-port\": ${RABBITMQ_PORT},
              \"service-host\": \"${RABBITMQ_HOST}\",
              \"rabbitmq-vhost\": \"/\",
              \"rabbitmq-outbound-queue\": \"test-ci\",
              \"rabbitmq-exchange\": \"ams-fanout\",
              \"rabbitmq-routing-key\": \"training\"
            },
            \"update_surrogate\": false
          },
          \"ml_models\": {},
          \"domain_models\": {}
        }""" > $GITHUB_WORKSPACE/build/tests/AMSlib/rmq.json
        make test
      env:
        RABBITMQ_USER: ams
        RABBITMQ_PASS: ams
        RABBITMQ_HOST: rabbitmq
        RABBITMQ_PORT: 5672
97 changes: 76 additions & 21 deletions src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import json
import pika


class AMSMessage(object):
"""
Represents a RabbitMQ incoming message from AMSLib.
Expand All @@ -28,6 +27,24 @@ class AMSMessage(object):
def __init__(self, body: str):
    """
    Hold a raw AMSLib message body.

    Header-derived fields all start as None and are filled in when the
    header is parsed (see _parse_header).

    @param body Raw binary payload received from RabbitMQ
    """
    self.body = body
    # Header fields, unknown until the message is decoded.
    for field in (
        "num_elements",
        "hsize",
        "dtype_byte",
        "mpi_rank",
        "domain_name_size",
        "input_dim",
        "output_dim",
    ):
        setattr(self, field, None)
    # One domain name is appended per AMS sub-message found in the body.
    self.domain_names = []

def __str__(self):
    """
    Human-readable summary of the parsed header fields.

    Fields are None until the message header has been parsed.
    """
    # dtype_byte is the element size in bytes: 4 -> float, 8 -> double.
    # Fix: the original printed the raw integer 8 instead of "double",
    # inconsistent with the 4-byte branch.
    dt = "float" if self.dtype_byte == 4 else "double"
    if not self.dtype_byte:
        # dtype_byte is None (or 0): the header has not been parsed yet.
        dt = None
    return f"AMSMessage(domain={self.domain_names}, #mpi={self.mpi_rank}, num_elements={self.num_elements}, datatype={dt}, input_dim={self.input_dim}, output_dim={self.output_dim})"

def __repr__(self):
    """Mirror __str__ so interactive display shows the same summary."""
    return str(self)

def header_format(self) -> str:
"""
This string represents the AMS format in Python pack format:
Expand Down Expand Up @@ -110,6 +127,15 @@ def _parse_header(self, body: str) -> dict:
res["dsize"] = int(res["datatype"]) * int(res["num_element"]) * (int(res["input_dim"]) + int(res["output_dim"]))
res["msg_size"] = hsize + res["dsize"]
res["multiple_msg"] = len(body) != res["msg_size"]

self.num_elements = int(res["num_element"])
self.hsize = int(res["hsize"])
self.dtype_byte = int(res["datatype"])
self.mpi_rank = int(res["mpirank"])
self.domain_name_size = int(res["domain_size"])
self.input_dim = int(res["input_dim"])
self.output_dim = int(res["output_dim"])

return res

def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
Expand Down Expand Up @@ -144,30 +170,37 @@ def _decode(self, body: str) -> Tuple[np.array]:
input = []
output = []
# Multiple AMS messages could be packed in one RMQ message
# TODO: we should manage potential multiple messages per AMSMessage better
while body:
header_info = self._parse_header(body)
domain_name, temp_input, temp_output = self._parse_data(body, header_info)
# print(f"MSG: {domain_name} input shape {temp_input.shape} output shape {temp_output.shape}")
# total size of byte we read for that message
chunk_size = header_info["hsize"] + header_info["dsize"] + header_info["domain_size"]
input.append(temp_input)
output.append(temp_output)
# We remove the current message and keep going
body = body[chunk_size:]
self.domain_names.append(domain_name)
return domain_name, np.concatenate(input), np.concatenate(output)

def decode(self) -> Tuple[str, np.array, np.array]:
    """Decode the stored body into (domain_name, inputs, outputs)."""
    raw = self.body
    return self._decode(raw)

def default_ams_callback(method, properties, body):
    """Simple callback that decodes an incoming message as an AMS binary message."""
    message = AMSMessage(body)
    return message

class AMSChannel:
"""
A wrapper around Pika RabbitMQ channel
"""

def __init__(self, connection, q_name, callback: Optional[Callable] = None, logger: Optional[logging.Logger] = None):
    """
    Wrap an existing pika connection for a single queue.

    @param connection An open pika connection object
    @param q_name Name of the queue this channel operates on
    @param callback Message handler; defaults to default_callback, which
        returns the raw message body unchanged
    @param logger Logger to use; defaults to this module's logger
    """
    self.connection = connection
    self.q_name = q_name
    self.logger = logger or logging.getLogger(__name__)
    self.callback = callback or self.default_callback

def __enter__(self):
self.open()
Expand All @@ -176,9 +209,9 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
    # Context-manager exit: close the channel. Returns None, so any
    # exception raised inside the `with` block still propagates.
    self.close()

def default_callback(self, method, properties, body):
    """Fallback message handler: hand back the message body untouched."""
    return body

def open(self):
self.channel = self.connection.channel()
Expand All @@ -187,18 +220,19 @@ def open(self):
def close(self):
    """Close the channel previously created by open()."""
    self.channel.close()

def receive(self, n_msg: int = None, accum_msg=list()):
def receive(self, n_msg: int = None, timeout: int = None, accum_msg = list()):
"""
Consume a message on the queue and post processing by calling the callback.
@param n_msg The number of messages to receive.
- if n_msg is None, this call will block for ever and will process all messages that arrives
- if n_msg = 1 for example, this function will block until one message has been processed.
@param timeout If None, the timeout is infinite; otherwise the timeout in seconds
@return a list containing all received messages
"""

if self.channel and self.channel.is_open:
self.logger.info(
f"Starting to consume messages from queue={self.q_name}, routing_key={self.routing_key} ..."
f"Starting to consume messages from queue={self.q_name} ..."
)
# we will consume only n_msg and requeue all other messages
# if there are more messages in the queue.
Expand All @@ -207,11 +241,15 @@ def receive(self, n_msg: int = None, accum_msg=list()):
n_msg = max(n_msg, 0)
message_consumed = 0
# Consume n_msg messages and break out
for method_frame, properties, body in self.channel.consume(self.q_name):
for method_frame, properties, body in self.channel.consume(self.q_name, inactivity_timeout=timeout):
if (method_frame, properties, body) == (None, None, None):
self.logger.info(f"Timeout after {timeout} seconds")
self.channel.cancel()
break
# Call the call on the message parts
try:
accum_msg.append(
BlockingClient.callback(
self.callback(
method_frame,
properties,
body,
Expand All @@ -223,23 +261,24 @@ def receive(self, n_msg: int = None, accum_msg=list()):
finally:
# Acknowledge the message even on failure
self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
message_consumed += 1
self.logger.warning(
f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})"
f"Consumed message {message_consumed}/{method_frame.delivery_tag} (exchange=\'{method_frame.exchange}\', routing_key={method_frame.routing_key})"
)
message_consumed += 1
# Escape out of the loop after nb_msg messages
if message_consumed == n_msg:
# Cancel the consumer and return any pending messages
self.channel.cancel()
break
return accum_msg

def send(self, text: str, exchange: str = ""):
    """
    Publish a single message on this channel.

    @param text The message payload to publish
    @param exchange Exchange to publish to; "" is RabbitMQ's default exchange,
        which routes directly to the queue named by the routing key
    """
    self.channel.basic_publish(
        exchange=exchange,
        routing_key=self.q_name,
        body=text,
    )
    return

def get_messages(self):
def purge(self):
    """Drop every message currently waiting in the queue, if the channel is open."""
    if self.channel and self.channel.is_open:
        self.channel.queue_purge(self.q_name)


class BlockingClient:
"""
BlockingClient is a class that manages a simple blocking RMQ client lifecycle.
"""

def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logger = None):
def __init__(
self,
host: str,
port: int,
vhost: str,
user: str,
password: str,
cert: Optional[str] = None,
callback: Optional[Callable] = None,
logger: Optional[logging.Logger] = None
):
# CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file):
# openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' rmq-pds.crt
self.logger = logger if logger else logging.getLogger(__name__)
self.cert = cert
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.check_hostname = False
self.context.load_verify_locations(self.cert)

if self.cert is None or self.cert == "":
ssl_options = None
else:
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.check_hostname = False
self.context.load_verify_locations(self.cert)
ssl_options = pika.SSLOptions(self.context)

self.host = host
self.vhost = vhost
self.port = port
self.user = user
self.password = password
self.callback = callback

self.credentials = pika.PlainCredentials(self.user, self.password)

Expand All @@ -278,7 +333,7 @@ def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logg
port=self.port,
virtual_host=self.vhost,
credentials=self.credentials,
ssl_options=pika.SSLOptions(self.context),
ssl_options=ssl_options,
)

def __enter__(self):
Expand All @@ -290,7 +345,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def connect(self, queue):
    """Return an AMSChannel bound to *queue*, reusing this client's connection and callback."""
    return AMSChannel(self.connection, queue, self.callback)


class AsyncConsumer(object):
Expand Down
Loading

0 comments on commit aa07df4

Please sign in to comment.