diff --git a/changes/181.feature b/changes/181.feature new file mode 100644 index 00000000..7e3d96a3 --- /dev/null +++ b/changes/181.feature @@ -0,0 +1 @@ +Display kernel-pull-progress from background-task-reporter diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index a5f524cc..d15dec19 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -298,6 +298,28 @@ async def get_or_create( rqst.set_json(params) async with rqst.fetch() as resp: data = await resp.json() + with tqdm(total=100, unit='%') as pbar: + task_id = data['background_task'] + bgtask = resp.session.BackgroundTask(task_id) + async with bgtask.listen_events() as response: + async for ev in response: + _data = json.loads(ev.data) + if ev.event == 'bgtask_updated': + current = _data['current_progress'] + total = _data['total_progress'] + if total==0: + total = 1e-2 + pbar.n = round(current / total * 100, 2) + pbar.update(0) + pbar.refresh() + elif ev.event == 'bgtask_done': + pbar.n = 100.0 + pbar.update(0) + pbar.refresh() + pbar.clear() + async with rqst.fetch() as resp: + data = await resp.json() + print('kernel pulling done...') o = cls(name, owner_access_key) # type: ignore if api_session.get().api_version[0] >= 5: o.id = UUID(data['sessionId']) @@ -306,6 +328,7 @@ async def get_or_create( o.service_ports = data.get('servicePorts', []) o.domain = domain_name o.group = group_name + return o @api_function