|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import json
|
4 |
| -from dataclasses import dataclass, field |
| 4 | +import requests |
| 5 | +from dataclasses import dataclass, field, fields |
5 | 6 | from datetime import timedelta
|
6 |
| -from typing import Any, Dict, Literal, Optional |
| 7 | +from typing import Any, ClassVar, Dict, Literal, Optional, Type |
7 | 8 |
|
8 | 9 | from jinja2 import Environment, meta, Template
|
9 | 10 |
|
@@ -152,16 +153,66 @@ def is_violation(self, value: float, baseline: float) -> bool:
|
152 | 153 |
|
153 | 154 | raise ValueError(f"Unknown condition: {self.condition}")
|
154 | 155 |
|
| 156 | +class BaseNotificationConfig: |
| 157 | + # every subclass must override this |
| 158 | + type_tag: ClassVar[str] |
| 159 | + |
| 160 | + @classmethod |
| 161 | + def from_dict(cls: Type[T], d: Dict[str, Any]) -> T: |
| 162 | + # pick only known fields for this dataclass |
| 163 | + kwargs = {f.name: d.get(f.name) for f in fields(cls)} |
| 164 | + return cls(**kwargs) # type: ignore |
| 165 | + |
| 166 | + @classmethod |
| 167 | + def matches(cls, d: Dict[str, Any]) -> bool: |
| 168 | + return d.get("type") == cls.type_tag |
| 169 | + |
| 170 | + |
| 171 | +@dataclass |
| 172 | +class GitHubNotificationConfig(BaseNotificationConfig): |
| 173 | + type: str = "github" |
| 174 | + repo: str = "" |
| 175 | + issue_number: str = "" |
| 176 | + type_tag: ClassVar[str] = "github" |
| 177 | + |
| 178 | + def create_github_comment(self, body: str, github_token: str) -> Dict[str, Any]: |
| 179 | + """ |
| 180 | + Create a new comment on a GitHub issue. |
| 181 | + Args: |
| 182 | + notification_config: dict with keys: |
| 183 | + - type: must be "github" |
| 184 | + - repo: "owner/repo" |
| 185 | + - issue: issue number (string or int) |
| 186 | + body: text of the comment |
| 187 | + token: GitHub personal access token or GitHub Actions token |
| 188 | +
|
| 189 | + Returns: |
| 190 | + The GitHub API response as a dict (JSON). |
| 191 | + """ |
| 192 | + url = f"https://api.github.com/repos/{self.repo}/issues/{self.issue_number}/comments" |
| 193 | + headers = { |
| 194 | + "Authorization": f"token {github_token}", |
| 195 | + "Accept": "application/vnd.github+json", |
| 196 | + "User-Agent": "bench-reporter/1.0", |
| 197 | + } |
| 198 | + resp = requests.post(url, headers=headers, json={"body": body}) |
| 199 | + resp.raise_for_status() |
| 200 | + return resp.json() |
| 201 | + |
155 | 202 |
|
156 | 203 | @dataclass
|
157 | 204 | class Policy:
|
158 | 205 | frequency: Frequency
|
159 | 206 | range: RangeConfig
|
160 | 207 | metrics: Dict[str, RegressionPolicy]
|
161 | 208 |
|
162 |
| - # TODO(elainewy): add notification config |
163 | 209 | notification_config: Optional[Dict[str, Any]] = None
|
164 | 210 |
|
| 211 | + def get_github_notification_config(self) -> Optional[GitHubNotificationConfig]: |
| 212 | + if not self.notification_config: |
| 213 | + return None |
| 214 | + return notification_from_dict(self.notification_config) # type: ignore |
| 215 | + |
165 | 216 |
|
166 | 217 | # -------- Top-level benchmark regression config --------
|
167 | 218 | @dataclass
|
|
0 commit comments