diff --git a/changes/181.feature b/changes/181.feature new file mode 100644 index 00000000..b9df3c38 --- /dev/null +++ b/changes/181.feature @@ -0,0 +1 @@ +Display the progress of pulling kernel from reporter of the background task. diff --git a/src/ai/backend/client/cli/run.py b/src/ai/backend/client/cli/run.py index aa0eacbe..74a4d217 100644 --- a/src/ai/backend/client/cli/run.py +++ b/src/ai/backend/client/cli/run.py @@ -15,6 +15,7 @@ Sequence, Tuple, ) +from tqdm import tqdm import aiohttp import click @@ -599,10 +600,45 @@ async def _run(session, idx, name, envs, except Exception as e: print_fail('[{0}] {1}'.format(idx, e)) return + + async def display_kernel_pulling(compute_session: AsyncSession.ComputeSession) -> bool: + try: + bgtask = compute_session.backgroundtask + except Exception as e: + print_error(e) + return False + else: + with tqdm(total=100, unit='%') as pbar: + async with bgtask.listen_events() as response: + async for ev in response: + progress = json.loads(ev.data) + if ev.event == 'bgtask_updated': + current = progress['current_progress'] + total = progress['total_progress'] + if total == 0: + pbar.n = 0 + else: + pbar.n = round(current / total * 100, 2) + pbar.update(0) + pbar.refresh() + elif ev.event == 'bgtask_done': + pbar.n = 100 + pbar.update(0) + pbar.refresh() + pbar.clear() + compute_session = await session.ComputeSession.get_or_create( + image, + name=name, + ) + await asyncio.sleep(0.1) + return True + if compute_session.status == 'PENDING': print_info('Session ID {0} is enqueued for scheduling.' .format(name)) - return + result = await display_kernel_pulling(compute_session) + if not result: + return elif compute_session.status == 'SCHEDULED': print_info('Session ID {0} is scheduled and about to be started.' .format(name)) @@ -623,7 +659,9 @@ async def _run(session, idx, name, envs, elif compute_session.status == 'TIMEOUT': print_info('Session ID {0} is still on the job queue.' .format(name)) - return + result = await display_kernel_pulling(compute_session) + if not result: + return elif compute_session.status in ('ERROR', 'CANCELLED'): print_fail('Session ID {0} has an error during scheduling/startup or cancelled.' .format(name)) diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index dd4a8053..d2371482 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -309,6 +309,9 @@ async def get_or_create( o.service_ports = data.get('servicePorts', []) o.domain = domain_name o.group = group_name + if 'background_task' in data: + task_id = data['background_task'] + o.backgroundtask = resp.session.BackgroundTask(task_id) return o @api_function