Skip to content
This repository has been archived by the owner on Sep 22, 2023. It is now read-only.

feat: Display kernel-pull-progress #181

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/181.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Display the progress of pulling kernel from reporter of the background task.
42 changes: 40 additions & 2 deletions src/ai/backend/client/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Sequence,
Tuple,
)
from tqdm import tqdm

import aiohttp
import click
Expand Down Expand Up @@ -599,10 +600,45 @@ async def _run(session, idx, name, envs,
except Exception as e:
print_fail('[{0}] {1}'.format(idx, e))
return

async def display_kernel_pulling(compute_session: AsyncSession.ComputeSession) -> bool:
try:
bgtask = compute_session.backgroundtask
except Exception as e:
print_error(e)
return False
else:
with tqdm(total=100, unit='%') as pbar:
async with bgtask.listen_events() as response:
async for ev in response:
progress = json.loads(ev.data)
if ev.event == 'bgtask_updated':
current = progress['current_progress']
total = progress['total_progress']
if total == 0:
pbar.n = 0
else:
pbar.n = round(current / total * 100, 2)
pbar.update(0)
pbar.refresh()
elif ev.event == 'bgtask_done':
pbar.n = 100
pbar.update(0)
pbar.refresh()
pbar.clear()
compute_session = await session.ComputeSession.get_or_create(
image,
name=name,
)
await asyncio.sleep(0.1)
return True

if compute_session.status == 'PENDING':
print_info('Session ID {0} is enqueued for scheduling.'
.format(name))
return
result = await display_kernel_pulling(compute_session)
if not result:
return
elif compute_session.status == 'SCHEDULED':
print_info('Session ID {0} is scheduled and about to be started.'
.format(name))
Expand All @@ -623,7 +659,9 @@ async def _run(session, idx, name, envs,
elif compute_session.status == 'TIMEOUT':
print_info('Session ID {0} is still on the job queue.'
.format(name))
return
result = await display_kernel_pulling(compute_session)
if not result:
return
elif compute_session.status in ('ERROR', 'CANCELLED'):
print_fail('Session ID {0} has an error during scheduling/startup or cancelled.'
.format(name))
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/client/func/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ async def get_or_create(
o.service_ports = data.get('servicePorts', [])
o.domain = domain_name
o.group = group_name
if 'background_task' in data:
task_id = data['background_task']
o.backgroundtask = resp.session.BackgroundTask(task_id)
return o

@api_function
Expand Down