p2p backend: heart beat (#223)
A Python gRPC server doesn't offer an easy way to check whether a
client has dropped. This makes it hard to manage endpoints in the
flame channel concept, and it can lead to deadlock situations where
an endpoint (acting as a gRPC server) waits for data to arrive from
an endpoint that has already dropped.

As a workaround, a heartbeat is sent periodically; if no heartbeat is
received for a certain duration, the client is assumed to be dropped.
The gRPC server then cleans up the resources allocated for that
endpoint, which prevents the deadlock.
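
The essence of the workaround is a resettable per-endpoint timer:
every received message (data or heartbeat) restarts the timer, and if
the timer ever fires, the peer is presumed dead. A minimal,
self-contained sketch of that pattern follows; it is illustrative
only, and the names Watchdog, on_dead, and TIMEOUT are not taken from
flame:

# Minimal sketch of the liveness pattern used here (illustrative only;
# Watchdog, on_dead, and TIMEOUT are not names from flame itself).
import asyncio

TIMEOUT = 15  # assumed value; the commit derives its timeout from QUEUE_WAIT_TIME

class Watchdog:
    def __init__(self, end_id, on_dead):
        self.end_id = end_id
        self.on_dead = on_dead  # async callback fired when the peer goes silent
        self._task = None

    def reset(self):
        # each received message (data or heartbeat) restarts the timer
        if self._task is not None:
            self._task.cancel()
        self._task = asyncio.ensure_future(self._expire())

    async def _expire(self):
        await asyncio.sleep(TIMEOUT)
        await self.on_dead(self.end_id)  # peer presumed dropped

async def main():
    async def cleanup(end_id):
        print(f"{end_id} presumed dropped; cleaning up")

    wd = Watchdog("end-1", cleanup)
    wd.reset()                        # a message arrived
    await asyncio.sleep(1)
    wd.reset()                        # another message; timer restarts
    await asyncio.sleep(TIMEOUT + 1)  # silence -> cleanup fires

asyncio.run(main())

The LiveChecker class added in the diff below implements the same
idea, with an extra guard against overly frequent resets.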
myungjin authored Sep 1, 2022
1 parent 9af0480 commit c6cc492
Showing 1 changed file with 104 additions and 36 deletions.
140 changes: 104 additions & 36 deletions lib/python/flame/backend/p2p.py
@@ -18,6 +18,7 @@
import asyncio
import logging
import socket
+import time
from typing import AsyncIterable, Iterable, Tuple

import grpc
@@ -34,7 +35,9 @@
logger = logging.getLogger(__name__)

ENDPOINT_TOKEN_LEN = 2
-HEART_BEAT_DURATION = 30
+HEART_BEAT_DURATION = 30  # for metaserver
+QUEUE_WAIT_TIME = 10  # 10 seconds
+EXTRA_WAIT_TIME = QUEUE_WAIT_TIME / 2


class BackendServicer(msg_pb2_grpc.BackendRouteServicer):
@@ -53,11 +56,12 @@ async def send_data(self, req_iter: AsyncIterable[msg_pb2.Data],
unused_context) -> msg_pb2.BackendID:
"""Implement a method to handle send_data request stream."""
# From server perspective, the server receives data from client.
logger.warn(f"address of req_iter = {hex(id(req_iter))}")
logger.warn(f"address of context = {hex(id(unused_context))}")

async for msg in req_iter:
await self.p2pbe._handle_data(msg)
self.p2pbe._set_heart_beat(msg.end_id)
# if the message is not a heart beat message,
# the message needs to be processed.
if msg.seqno != -1 or msg.eom is False or msg.channel_name != "":
await self.p2pbe._handle_data(msg)

return msg_pb2.BackendID(end_id=self.p2pbe._id)

@@ -67,6 +71,7 @@ async def recv_data(self, req: msg_pb2.BackendID,
# From server perspective, the server sends data to client.
dck_task = asyncio.create_task(self._dummy_context_keeper(context))
self.p2pbe._set_writer(req.end_id, context)
+        self.p2pbe._set_heart_beat(req.end_id)

await dck_task

@@ -101,18 +106,15 @@ def __init__(self):
self._backend = None

self._endpoints = {}
-        self.end_to_rwop = {}
        self._channels = {}
+        self._livecheck = {}

with background_thread_loop() as loop:
self._loop = loop

async def _init_loop_stuff():
self._eventq = asyncio.Queue()

-            coro = self._monitor_end_termination()
-            _ = asyncio.create_task(coro)
-
coro = self._setup_server()
_ = asyncio.create_task(coro)

@@ -124,20 +126,6 @@ async def _init_loop_stuff():

self._initialized = True

-    async def _monitor_end_termination(self):
-        # TODO: handle how to monitor grpc channel status
-        # while True:
-        #     for end_id, (reader, _) in list(self._endpoints.items()):
-        #         if not reader.at_eof():
-        #             continue
-
-        #         # connection is closed
-        #         await self._eventq.put((BackendEvent.DISCONNECT, end_id))
-        #         await self._close(end_id)
-
-        #         await asyncio.sleep(1)
-        pass
-
async def _setup_server(self):
server = grpc.aio.server()
msg_pb2_grpc.add_BackendRouteServicer_to_server(
@@ -222,7 +210,7 @@ async def _register_channel(self, channel) -> None:
raise SystemError('registration failure')

for endpoint in meta_resp.endpoints:
logger.info(f"endpoint: {endpoint}")
logger.debug(f"connecting to endpoint: {endpoint}")
await self._connect_and_notify(endpoint, channel.name())

while True:
@@ -249,7 +237,7 @@ async def notify(self, channel_name, notify_type, stub, grpc_ch) -> bool:
try:
resp = await stub.notify_end(msg)
except grpc.aio.AioRpcError:
logger.warn("can't proceed as grpc channel is unavailable")
logger.debug("can't proceed as grpc channel is unavailable")
return False

logger.debug(f"resp = {resp}")
@@ -370,33 +358,56 @@ async def _broadcast_task(self, channel):
await self.send_chunks(end_id, channel.name(), data)
except Exception as ex:
ex_name = type(ex).__name__
logger.warn(f"An exception of type {ex_name} occurred")
logger.debug(f"An exception of type {ex_name} occurred")

await self._eventq.put((BackendEvent.DISCONNECT, end_id))
del self._endpoints[end_id]
self._cleanup_end(end_id)
txq.task_done()

async def _unicast_task(self, channel, end_id):
txq = channel.get_txq(end_id)

while True:
-            data = await txq.get()
+            try:
+                data = await asyncio.wait_for(txq.get(), QUEUE_WAIT_TIME)
+            except asyncio.TimeoutError:
+                if end_id not in self._endpoints:
+                    logger.debug(f"end_id {end_id} not in _endpoints")
+                    break
+
+                _, _, clt_writer, _ = self._endpoints[end_id]
+                if clt_writer is None:
+                    continue
+
+                def heart_beat():
+                    # the condition for heart beat message:
+                    #    channel_name = ""
+                    #    seqno = -1
+                    #    eom = True
+                    msg = msg_pb2.Data(end_id=self._id,
+                                       channel_name="",
+                                       payload=b'',
+                                       seqno=-1,
+                                       eom=True)
+
+                    yield msg
+
+                await clt_writer.send_data(heart_beat())
+                continue

try:
await self.send_chunks(end_id, channel.name(), data)
except Exception as ex:
ex_name = type(ex).__name__
logger.warn(f"An exception of type {ex_name} occurred")
logger.debug(f"An exception of type {ex_name} occurred")

await self._eventq.put((BackendEvent.DISCONNECT, end_id))
del self._endpoints[end_id]
self._cleanup_end(end_id)
txq.task_done()
# This break ends a tx_task for end_id
break

txq.task_done()

logger.warn(f"unicast task for {end_id} terminated")
logger.debug(f"unicast task for {end_id} terminated")

async def send_chunks(self, other: str, ch_name: str, data: bytes) -> None:
"""Send data chunks to an end."""
@@ -446,7 +457,7 @@ async def _rx_task(self, end_id: str, reader) -> None:
try:
msg = await reader.read()
except grpc.aio.AioRpcError:
logger.info(f"AioRpcError occurred for {end_id}")
logger.debug(f"AioRpcError occurred for {end_id}")
break

if msg == grpc.aio.EOF:
@@ -456,7 +467,64 @@

# grpc channel is unavailable
# so, clean up an entry for end_id from _endpoints dict
-        await self._eventq.put((BackendEvent.DISCONNECT, end_id))
-        del self._endpoints[end_id]
-        logger.info(f"cleaned up {end_id} info from _endpoints")
+        await self._cleanup_end(end_id)
+
+        logger.debug(f"cleaned up {end_id} info from _endpoints")
+
+    async def _cleanup_end(self, end_id):
+        await self._eventq.put((BackendEvent.DISCONNECT, end_id))
+        if end_id in self._endpoints:
+            del self._endpoints[end_id]
+        if end_id in self._livecheck:
+            self._livecheck[end_id].cancel()
+            del self._livecheck[end_id]

+    def _set_heart_beat(self, end_id) -> None:
+        logger.debug(f"heart beat data message for {end_id}")
+        if end_id not in self._livecheck:
+            timeout = QUEUE_WAIT_TIME + 5
+            self._livecheck[end_id] = LiveChecker(self, end_id, timeout)
+
+        self._livecheck[end_id].reset()


+class LiveChecker:
+    """LiveChecker class."""
+
+    def __init__(self, p2pbe, end_id, timeout) -> None:
+        """Initialize an instance."""
+        self._p2pbe = p2pbe
+        self._end_id = end_id
+        self._timeout = timeout
+
+        self._task = None
+        self._last_reset = time.time()
+
+    async def _check(self):
+        await asyncio.sleep(self._timeout)
+        await self._p2pbe._cleanup_end(self._end_id)
+        logger.debug(f"live check timeout occurred for {self._end_id}")
+
+    def cancel(self) -> None:
+        """Cancel a task."""
+        if self._task is None or self._task.cancelled():
+            return
+
+        self._task.cancel()
+        logger.debug(f"cancelled task for {self._end_id}")
+
+    def reset(self) -> None:
+        """Reset a task."""
+        now = time.time()
+        if now - self._last_reset < EXTRA_WAIT_TIME / 2:
+            # this is to prevent too frequent reset
+            logger.debug("this is to prevent too frequent reset")
+            return
+
+        self._last_reset = now
+
+        self.cancel()
+
+        self._task = asyncio.ensure_future(self._check())
+
+        logger.debug(f"set future for {self._end_id}")
