Skip to content

Commit

Permalink
Initial implementation of unary-stream in web_aiohttp transport and a…
Browse files Browse the repository at this point in the history
…iohttp channel
  • Loading branch information
Aksem committed Nov 3, 2024
1 parent ffce63f commit 9e8e317
Show file tree
Hide file tree
Showing 4 changed files with 415 additions and 119 deletions.
138 changes: 120 additions & 18 deletions modapp/channels/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
from typing import Any, AsyncIterator, Dict, Optional, Type, TypeVar
import asyncio
import json
from typing import Any, AsyncIterator, Type, TypeVar

import aiohttp
from loguru import logger
from typing_extensions import override

from modapp.base_converter import BaseConverter
from modapp.base_model import BaseModel
from modapp.client import BaseChannel

T = TypeVar("T", bound=BaseModel)
StreamClosedMessage = object()


class AioHttpChannel(BaseChannel):
"""
NOTE: aiohttp conflicts with web_socketify, requests cannot be sent in web_socketify transport
"""

def __init__(self, converter: BaseConverter, server_address: str) -> None:
super().__init__(converter)
self.server_address = server_address
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._ws_connection_id: str | None = None
self._msg_queue_by_stream_id: dict[str, asyncio.Queue] = {}
self._ws_message_processing_task: asyncio.Task | None = None

@override
async def send_unary_unary(
self,
route_path: str,
request_data: BaseModel,
request: BaseModel,
reply_cls: Type[T],
meta: Optional[Dict[str, Any]] = None,
meta: dict[str, Any] | None = None,
timeout: float | None = 5,
) -> T:
raw_data = self.converter.model_to_raw(request_data)
raw_data = self.converter.model_to_raw(request)

# TODO: save client session
async with aiohttp.ClientSession() as session:
# TODO: check route path
async with session.post(
self.server_address + route_path.replace(".", "/").lower(),
data=raw_data,
timeout=aiohttp.ClientTimeout(total=timeout)
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
raw_reply = await response.read()

Expand All @@ -48,20 +57,38 @@ async def send_unary_unary(
async def send_unary_stream(
self,
route_path: str,
request_data: BaseModel,
request: BaseModel,
reply_cls: Type[T],
meta: Optional[Dict[str, Any]] = None,
meta: dict[str, Any] | None = None,
) -> AsyncIterator[T]:
raise NotImplementedError()
# raw_data = self.converter.model_to_raw(request_data)
if self._ws is None:
await self._connect_to_ws()
assert self._ws is not None
assert self._ws_connection_id is not None

# Send HTTP request to start stream
raw_data = self.converter.model_to_raw(request)
async with aiohttp.ClientSession() as session:
async with session.post(
self.server_address + route_path.replace(".", "/").lower(),
data=raw_data,
timeout=aiohttp.ClientTimeout(),
headers={"Connection-Id": self._ws_connection_id},
) as response:
stream_id = response.headers.get("Stream-Id")

# TODO
if stream_id is None:
raise Exception() # TODO

# assert isinstance(
# reply_iterator, AsyncIterator
# ), "Reply on unary-stream request should be async iterator of bytes"
# async for raw_message in reply_iterator:
# yield self.converter.raw_to_model(raw_message, reply_cls)
stream_queue = asyncio.Queue()
self._msg_queue_by_stream_id[stream_id] = stream_queue
while True:
raw_message = await stream_queue.get()
if raw_message == StreamClosedMessage:
del self._msg_queue_by_stream_id[stream_id]
break
message = self.converter.raw_to_model(raw_message, reply_cls)
yield message

@override
async def send_stream_unary(self) -> None:
Expand All @@ -72,5 +99,80 @@ async def send_stream_stream(self) -> None:
raise NotImplementedError()

@override
def __aexit__(self) -> None:
pass
async def __aexit__(self, exc_type, exc, tb) -> None:
if self._session is not None:
await self._session.__aexit__(exc_type, exc, tb)
self._session = None
if self._ws is not None:
await self._ws.__aexit__(exc_type, exc, tb)
self._ws = None

async def _connect_to_ws(self):
if self._session is not None and self._ws is not None:
logger.debug("Already connected to websocket")
return

# TODO: handle timeout
self._session = await aiohttp.ClientSession().__aenter__()
self._ws = await self._session.ws_connect(
f"{self.server_address}/ws"
).__aenter__()

# TODO: catch timeout
first_message = await asyncio.wait_for(anext(self._ws), timeout=10)
if first_message.type == aiohttp.WSMsgType.ERROR:
# error happened, raise exception
raise Exception() # TODO
elif first_message.type == aiohttp.WSMsgType.TEXT:
first_message_json = json.loads(first_message.data)
try:
# TODO: validate type
self._ws_connection_id = first_message_json["connectionId"]
except KeyError:
raise Exception() # TODO
else:
raise Exception() # TODO

self._ws_message_processing_task = asyncio.create_task(
self.process_ws_messages()
)

async def process_ws_messages(self):
assert self._ws is not None
async for msg in self._ws:
msg_json = json.loads(msg.data)
try:
stream_id = msg_json["streamId"]
except KeyError:
logger.error("No streamId in ws message, skip it")
continue

try:
stream_queue = self._msg_queue_by_stream_id[stream_id]
except KeyError:
logger.error(f"No queue found for streamId {stream_id}, skip it")
continue

try:
stream_msg = msg_json["message"]
except KeyError:
stream_end_msg = msg_json.get("end", None)
if stream_end_msg is not None:
if stream_end_msg is True:
await stream_queue.put(StreamClosedMessage)
else:
logger.error(
f"Field 'end' has unsupported value '{stream_end_msg}', only 'true' is supported"
)
continue
else:
logger.error(
"Neither 'message' field nor 'end' field in ws message, skip it"
)
continue

if msg.type == aiohttp.WSMsgType.TEXT:
await stream_queue.put(stream_msg)
elif msg.type == aiohttp.WSMsgType.ERROR:
await stream_queue.put(StreamClosedMessage)
break
Loading

0 comments on commit 9e8e317

Please sign in to comment.