Skip to content

Commit 5caf5c1

Browse files
authored
Fix Timeout Bug (#280)
* use workflows timeout * remove magic number * lint * Update nvidia_workflow.yml * Update nvidia_workflow.yml * fix verify tests * make reviewable * add buffer
1 parent a4aaf08 commit 5caf5c1

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

.github/workflows/nvidia_workflow.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
shell: bash
3030
run: |
3131
# Extract the payload content without printing it
32+
apt-get update && apt-get install -y jq
3233
PAYLOAD=$(jq -r '.inputs.payload' $GITHUB_EVENT_PATH)
3334
3435
# Apply mask to the extracted content

src/discord-cluster-manager/consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ class RankCriterion(Enum):
145145
]
146146
MODAL_CUDA_INCLUDE_DIRS = ["/ThunderKittens/include"]
147147

148+
DEFAULT_GITHUB_TIMEOUT_MINUTES = 10 # Default timeout for GitHub launcher in minutes
149+
148150
NVIDIA_REQUIREMENTS = """
149151
numpy
150152
torch
@@ -157,3 +159,6 @@ class RankCriterion(Enum):
157159
--index-url https://download.pytorch.org/whl/rocm6.2.4
158160
torch
159161
"""
162+
163+
# Extra minutes added on top of the run timeout to account for GitHub setup time
164+
TIMEOUT_BUFFER_MINUTES = 2

src/discord-cluster-manager/launchers/github.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
import asyncio
22
import datetime
33
import json
4+
import math
45
import pprint
56
import tempfile
67
import zipfile
78
from typing import Awaitable, Callable, Optional
89

910
import requests
10-
from consts import AMD_REQUIREMENTS, GPU, NVIDIA_REQUIREMENTS, GitHubGPU
11+
from consts import (
12+
AMD_REQUIREMENTS,
13+
DEFAULT_GITHUB_TIMEOUT_MINUTES,
14+
GPU,
15+
NVIDIA_REQUIREMENTS,
16+
TIMEOUT_BUFFER_MINUTES,
17+
GitHubGPU,
18+
SubmissionMode,
19+
)
1120
from github import Github, UnknownObjectException, WorkflowRun
1221
from report import RunProgressReporter
1322
from run_eval import CompileResult, EvalResult, FullResult, RunResult, SystemInfo
@@ -17,6 +26,15 @@
1726

1827
logger = setup_logging()
1928

29+
def get_timeout(config: dict) -> int:
    """Return the timeout for a submission config, in whole minutes.

    The submission mode found under ``config["mode"]`` selects which
    per-mode timeout entry is consulted:

    * ``SubmissionMode.TEST``        -> ``config["test_timeout"]``
    * ``SubmissionMode.BENCHMARK``   -> ``config["benchmark_timeout"]``
    * ``SubmissionMode.LEADERBOARD`` -> ``config["ranked_timeout"]``

    The selected entry is treated as a number of seconds. When the mode is
    not one of the above, or the selected entry is missing or falsy
    (``None``, ``0``), ``DEFAULT_GITHUB_TIMEOUT_MINUTES`` is used instead.

    Args:
        config: Submission configuration mapping.

    Returns:
        The timeout converted from seconds to minutes, rounded up.
    """
    timeout_key_by_mode = {
        SubmissionMode.TEST.value: "test_timeout",
        SubmissionMode.BENCHMARK.value: "benchmark_timeout",
        SubmissionMode.LEADERBOARD.value: "ranked_timeout",
    }

    timeout_key = timeout_key_by_mode.get(config.get("mode"))
    configured_seconds = None if timeout_key is None else config.get(timeout_key)

    # A missing or zero timeout falls back to the default (stored in minutes).
    if not configured_seconds:
        configured_seconds = DEFAULT_GITHUB_TIMEOUT_MINUTES * 60

    # Round up so a partial minute is never truncated away.
    return math.ceil(configured_seconds / 60)
2038

2139
class GitHubLauncher(Launcher):
2240
def __init__(self, repo: str, token: str):
@@ -70,7 +88,13 @@ async def run_submission(
7088

7189
await status.push("⏳ Waiting for workflow to start...")
7290
logger.info("Waiting for workflow to start...")
73-
await run.wait_for_completion(lambda x: self.wait_callback(x, status))
91+
92+
timeout = get_timeout(config) + TIMEOUT_BUFFER_MINUTES
93+
logger.info(f"Waiting for workflow to complete... (timeout: {timeout} minutes)")
94+
await run.wait_for_completion(
95+
lambda x: self.wait_callback(x, status),
96+
timeout_minutes=timeout
97+
)
7498
await status.update(f"Workflow [{run.run_id}]({run.html_url}) completed")
7599
logger.info(f"Workflow [{run.run_id}]({run.html_url}) completed")
76100
await status.push("Downloading artifacts...")

0 commit comments

Comments
 (0)