|
1 | 1 | import asyncio
|
2 | 2 | import datetime
|
3 | 3 | import json
|
| 4 | +import math |
4 | 5 | import pprint
|
5 | 6 | import tempfile
|
6 | 7 | import zipfile
|
7 | 8 | from typing import Awaitable, Callable, Optional
|
8 | 9 |
|
9 | 10 | import requests
|
10 |
| -from consts import AMD_REQUIREMENTS, GPU, NVIDIA_REQUIREMENTS, GitHubGPU |
| 11 | +from consts import ( |
| 12 | + AMD_REQUIREMENTS, |
| 13 | + DEFAULT_GITHUB_TIMEOUT_MINUTES, |
| 14 | + GPU, |
| 15 | + NVIDIA_REQUIREMENTS, |
| 16 | + TIMEOUT_BUFFER_MINUTES, |
| 17 | + GitHubGPU, |
| 18 | + SubmissionMode, |
| 19 | +) |
11 | 20 | from github import Github, UnknownObjectException, WorkflowRun
|
12 | 21 | from report import RunProgressReporter
|
13 | 22 | from run_eval import CompileResult, EvalResult, FullResult, RunResult, SystemInfo
|
|
17 | 26 |
|
18 | 27 | logger = setup_logging()
|
19 | 28 |
|
| 29 | +def get_timeout(config: dict) -> int: |
| 30 | + mode = config.get("mode") |
| 31 | + sec_map = { |
| 32 | + SubmissionMode.TEST.value: config.get("test_timeout"), |
| 33 | + SubmissionMode.BENCHMARK.value: config.get("benchmark_timeout"), |
| 34 | + SubmissionMode.LEADERBOARD.value: config.get("ranked_timeout"), |
| 35 | + } |
| 36 | + seconds = sec_map.get(mode) or DEFAULT_GITHUB_TIMEOUT_MINUTES * 60 |
| 37 | + return math.ceil(seconds / 60) |
20 | 38 |
|
21 | 39 | class GitHubLauncher(Launcher):
|
22 | 40 | def __init__(self, repo: str, token: str):
|
@@ -70,7 +88,13 @@ async def run_submission(
|
70 | 88 |
|
71 | 89 | await status.push("⏳ Waiting for workflow to start...")
|
72 | 90 | logger.info("Waiting for workflow to start...")
|
73 |
| - await run.wait_for_completion(lambda x: self.wait_callback(x, status)) |
| 91 | + |
| 92 | + timeout = get_timeout(config) + TIMEOUT_BUFFER_MINUTES |
| 93 | + logger.info(f"Waiting for workflow to complete... (timeout: {timeout} minutes)") |
| 94 | + await run.wait_for_completion( |
| 95 | + lambda x: self.wait_callback(x, status), |
| 96 | + timeout_minutes=timeout |
| 97 | + ) |
74 | 98 | await status.update(f"Workflow [{run.run_id}]({run.html_url}) completed")
|
75 | 99 | logger.info(f"Workflow [{run.run_id}]({run.html_url}) completed")
|
76 | 100 | await status.push("Downloading artifacts...")
|
|
0 commit comments