Skip to content

Commit 13ce781

Browse files
committed
add github notification
ghstack-source-id: 65127a4 Pull-Request: #7096
1 parent 35f802a commit 13ce781

File tree

4 files changed

+124
-25
lines changed

4 files changed

+124
-25
lines changed

aws/lambda/benchmark_regression_summary_report/common/config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# their own benchmark regression config, currently placed
1616
# here for lambda
1717

18-
1918
COMPILER_BENCHMARK_CONFIG = BenchmarkConfig(
2019
name="Compiler Benchmark Regression",
2120
id="compiler_regression",
@@ -67,7 +66,7 @@
6766
"compression_ratio": RegressionPolicy(
6867
name="compression_ratio",
6968
condition="greater_equal",
70-
threshold=0.9,
69+
threshold=0.95,
7170
baseline_aggregation="max",
7271
),
7372
},

aws/lambda/benchmark_regression_summary_report/common/config_model.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import json
44
from dataclasses import dataclass, field
55
from datetime import timedelta
6-
from typing import Any, Dict, Literal, Optional
6+
from typing import Any, ClassVar, Dict, Literal, Optional
77

8+
import requests
89
from jinja2 import Environment, meta, Template
910

1011

@@ -154,13 +155,60 @@ def is_violation(self, value: float, baseline: float) -> bool:
154155

155156

156157
@dataclass
157-
class Policy:
158-
frequency: Frequency
159-
range: RangeConfig
160-
metrics: Dict[str, RegressionPolicy]
158+
class BaseNotificationConfig:
    """Common base for notification configs parsed from a policy dict.

    Subclasses set ``type_tag`` to the value of the ``"type"`` key they handle.
    """

    # subclasses override this
    type_tag: ClassVar[str] = ""

    @classmethod
    def matches(cls, d: Dict[str, Any]) -> bool:
        """Return True when ``d["type"]`` equals this class's ``type_tag``."""
        return d.get("type") == cls.type_tag


@dataclass
class GitHubNotificationConfig(BaseNotificationConfig):
    """Notification target that posts a comment on a GitHub issue."""

    type_tag: ClassVar[str] = "github"

    # actual fields
    type: str = "github"
    repo: str = ""  # e.g. "owner/repo"
    issue_number: str = ""  # store as str for simplicity

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "GitHubNotificationConfig":
        """Build a config from a plain dict.

        Reads ``repo`` and ``issue_number`` (``issue`` is accepted as an
        alias). Missing or ``None`` values become empty strings; the issue
        number is always stored as a string.
        """
        # support 'issue' alias
        issue = d.get("issue_number") or d.get("issue") or ""
        return cls(
            type=cls.type_tag,
            # `or ""` also covers an explicit `"repo": None` entry.
            repo=str(d.get("repo") or ""),
            issue_number=str(issue),
        )

    def create_github_comment(
        self, body: str, github_token: str, timeout: float = 10.0
    ) -> Dict[str, Any]:
        """Create a comment with ``body`` on the configured GitHub issue.

        Args:
            body: text of the comment.
            github_token: token sent in the ``Authorization`` header.
            timeout: seconds to wait for the HTTP request before giving up.

        Returns:
            The decoded JSON response of the POST request.

        Raises:
            ValueError: if ``repo`` or ``issue_number`` is empty.
            requests.HTTPError: if the response status is an error
                (raised by ``raise_for_status``).
        """
        # Fail fast instead of posting to a malformed ".../repos//issues//comments" URL.
        if not self.repo or not self.issue_number:
            raise ValueError(
                "GitHub notification requires both 'repo' and 'issue_number', "
                f"got repo={self.repo!r}, issue_number={self.issue_number!r}"
            )

        url = f"https://api.github.com/repos/{self.repo}/issues/{self.issue_number}/comments"
        headers = {
            "Authorization": f"token {github_token}",
            "Accept": "application/vnd.github+json",
            "User-Agent": "bench-reporter/1.0",
        }
        # requests has no default timeout; without one a stalled connection
        # would block the caller indefinitely.
        resp = requests.post(
            url, headers=headers, json={"body": body}, timeout=timeout
        )
        resp.raise_for_status()
        return resp.json()
196+
161197

162-
# TODO(elainewy): add notification config
163-
notification_config: Optional[Dict[str, Any]] = None
198+
@dataclass
class Policy:
    """Regression policy attached to a benchmark config.

    Holds the report frequency, the range config, the per-metric regression
    policies and an optional notification config dict.
    """

    frequency: "Frequency"
    range: "RangeConfig"
    # keyed by metric name
    metrics: Dict[str, "RegressionPolicy"]

    # Raw notification settings; a "type" key selects the notification kind.
    # Uses typing.Dict to match the other annotations in this class.
    notification_config: Optional[Dict[str, Any]] = None

    def get_github_notification_config(self) -> Optional["GitHubNotificationConfig"]:
        """Return the GitHub notification config, if one is configured.

        Returns:
            A ``GitHubNotificationConfig`` built from ``notification_config``
            when that dict is non-empty and its ``"type"`` is ``"github"``;
            otherwise ``None``.
        """
        if not self.notification_config:
            return None
        if self.notification_config.get("type") != "github":
            return None
        return GitHubNotificationConfig.from_dict(self.notification_config)
164212

165213

166214
# -------- Top-level benchmark regression config --------

aws/lambda/benchmark_regression_summary_report/common/report_manager.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,36 @@ def __init__(
7171
self.db_table_name = db_table_name
7272
self.id = str(uuid.uuid4())
7373

74-
def run(self, cc: clickhouse_connect.driver.client.Client) -> None:
74+
def run(
    self, cc: clickhouse_connect.driver.client.Client, github_token: str
) -> None:
    """Persist the report, then send the GitHub notification.

    Args:
        cc: ClickHouse client passed to ``insert_to_db``.
        github_token: token passed to ``notify_github_comment``.

    Raises:
        Exception: any error raised by ``insert_to_db`` is logged and
            re-raised; the notification step is skipped in that case.
    """
    try:
        self.insert_to_db(cc)
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("failed to insert report to db, error: %s", e)
        raise
    self.notify_github_comment(github_token)
83+
84+
def notify_github_comment(self, github_token: str):
    """Post the rendered report as a GitHub issue comment on regression.

    Only logs and returns when the report status is not "regression" or
    when the policy has no GitHub notification config.

    Args:
        github_token: token forwarded to ``create_github_comment``.
    """
    if self.status != "regression":
        logger.info("[%s] no regression found, skip notification", self.config_id)
        return

    gh_config = self.config.policy.get_github_notification_config()
    if gh_config:
        logger.info("[%s] prepareing content", self.config_id)
        markdown_body = self._to_markdoown()
        logger.info("[%s] create comment to github issue", self.config_id)
        gh_config.create_github_comment(markdown_body, github_token)
        logger.info("[%s] done. comment is sent to github", self.config_id)
        return

    logger.info(
        "[%s] no github notification config found, skip notification",
        self.config_id,
    )
80104

81105
def _to_markdoown(self):
82106
md = Template(REPORT_MD_TEMPLATE, trim_blocks=True, lstrip_blocks=True).render(
@@ -108,21 +132,19 @@ def insert_to_db(
108132
# ---- convert to UTC and format as ClickHouse-friendly 'YYYY-MM-DD HH:MM:SS' ----
109133
aware = dt.datetime.fromisoformat(latest_ts_str.replace("Z", "+00:00"))
110134
utc_naive = aware.astimezone(dt.timezone.utc).replace(tzinfo=None)
111-
last_record_ts = utc_naive.strftime(
112-
"%Y-%m-%d %H:%M:%S"
113-
) # 给 {DateTime64(0)} 用
135+
last_record_ts = utc_naive.strftime("%Y-%m-%d %H:%M:%S")
114136

115137
report_json = json.dumps(
116138
self.report_data, ensure_ascii=False, separators=(",", ":"), default=str
117139
)
118140

119141
params = {
120-
"id": str(self.id), # 列是 UUID,用 {id:UUID}
142+
"id": str(self.id),
121143
"report_id": self.config_id,
122144
"type": self.type,
123145
"status": self.status,
124146
"last_record_commit": self.latest_meta_info.get("commit", ""),
125-
"last_record_ts": last_record_ts, # 已是 UTC,无时区
147+
"last_record_ts": last_record_ts,
126148
"regression_count": int(self.regression_summary.get("regression_count", 0)),
127149
"insufficient_data_count": int(
128150
self.regression_summary.get("insufficient_data_count", 0)
@@ -139,7 +161,7 @@ def insert_to_db(
139161
"[%s]inserting benchmark regression report(%s)", self.config_id, self.id
140162
)
141163

142-
# INSERT ... SELECT ... FROM system.one + NOT EXISTS 保护
164+
# INSERT ... SELECT ... FROM system.one + NOT EXISTS protection
143165
cc.query(
144166
f"""
145167
INSERT INTO {table} (

aws/lambda/benchmark_regression_summary_report/lambda_function.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
logger.setLevel("INFO")
3232

3333
ENVS = {
34-
"GITHUB_ACCESS_TOKEN": os.getenv("GITHUB_ACCESS_TOKEN", ""),
34+
"GITHUB_TOKEN": os.getenv("GITHUB_TOKEN", ""),
3535
"CLICKHOUSE_ENDPOINT": os.getenv("CLICKHOUSE_ENDPOINT", ""),
3636
"CLICKHOUSE_PASSWORD": os.getenv("CLICKHOUSE_PASSWORD", ""),
3737
"CLICKHOUSE_USERNAME": os.getenv("CLICKHOUSE_USERNAME", ""),
@@ -49,6 +49,7 @@ def get_clickhouse_client(
4949
host: str, user: str, password: str
5050
) -> clickhouse_connect.driver.client.Client:
5151
# for local testing only, disable SSL verification
52+
logger.info("get_clickhouse_client ...")
5253
return clickhouse_connect.get_client(
5354
host=host, user=user, password=password, secure=True, verify=False
5455
)
@@ -115,12 +116,13 @@ def log_error(msg: str):
115116

116117
# check if the current time is > policy's time_delta + previous record_ts from summary_table
117118
report_freq = config.policy.frequency
119+
118120
should_generate = self._should_generate_report(
119121
cc, end_time, config_id, report_freq
120122
)
121123
if not should_generate:
122124
log_info(
123-
f"Skip generate report for time:{end_time} with frequency {report_freq.get_text()}, no data found",
125+
"Skip generate report",
124126
)
125127
return
126128
else:
@@ -210,7 +212,8 @@ def get_baseline(self, config: BenchmarkConfig, end_time: dt.datetime):
210212
)
211213

212214
logger.info(
213-
"found %s # of data, with time range %s",
215+
"[%s] found %s # of data, with time range %s",
216+
config.id,
214217
len(raw_data.time_series),
215218
raw_data.time_range,
216219
)
@@ -325,6 +328,7 @@ def _get_latest_record_ts(
325328
""",
326329
parameters={"config_id": config_id},
327330
)
331+
328332
if not res.result_rows or res.result_rows[0][0] is None:
329333
return None
330334
latest: dt.datetime = res.result_rows[0][
@@ -337,16 +341,42 @@ def _get_latest_record_ts(
337341

338342
freq_delta = f.to_timedelta()
339343
latest_record_ts = _get_latest_record_ts(cc, config_id)
340-
341344
# No report exists yet, generate
342345
if not latest_record_ts:
346+
logger.info("[%s] no latest record ts from db for the config_id", config_id)
343347
return True
348+
logger.info(
349+
"[%s] found latest record ts from db %s", config_id, latest_record_ts
350+
)
344351
end_utc = (
345352
end_time if end_time.tzinfo else end_time.replace(tzinfo=dt.timezone.utc)
346353
)
347354
end_utc = end_utc.astimezone(dt.timezone.utc)
348-
cutoff = end_time - freq_delta
349-
return latest_record_ts < cutoff
355+
time_boundary = latest_record_ts + freq_delta
356+
should_generate = end_time > time_boundary
357+
358+
if not should_generate:
359+
logger.info(
360+
"[%s][frequency(%s)] skip generate report. end_time(%s) must greater than "
361+
"time_boundary(%s) based on latest_record_ts(%s)",
362+
config_id,
363+
f.get_text(),
364+
end_time,
365+
time_boundary,
366+
latest_record_ts,
367+
)
368+
else:
369+
logger.info(
370+
"[%s][frequency(%s)] plan to generate report. end_time(%s) is greater than "
371+
"time_boundary(%s) based on latest_record_ts(%s)",
372+
config_id,
373+
f.get_text(),
374+
end_time,
375+
time_boundary,
376+
latest_record_ts,
377+
)
378+
379+
return should_generate
350380

351381

352382
class WorkerPoolHandler:
@@ -416,7 +446,7 @@ def main(
416446
2. call WorkerPoolHandler to generate and write histogram data for each interval in parallel
417447
"""
418448
if not github_access_token:
419-
raise ValueError("Missing environment variable GITHUB_ACCESS_TOKEN")
449+
raise ValueError("Missing environment variable GITHUB_TOKEN")
420450

421451
# get time intervals.
422452
logger.info("[Main] start work ....")
@@ -435,7 +465,7 @@ def lambda_handler(event: Any, context: Any) -> None:
435465
"""
436466
main(
437467
None,
438-
github_access_token=ENVS["GITHUB_ACCESS_TOKEN"],
468+
github_access_token=ENVS["GITHUB_TOKEN"],
439469
)
440470
return
441471

@@ -467,7 +497,7 @@ def parse_args() -> argparse.Namespace:
467497
parser.add_argument(
468498
"--github-access-token",
469499
type=str,
470-
default=ENVS["GITHUB_ACCESS_TOKEN"],
500+
default=ENVS["GITHUB_TOKEN"],
471501
help="the github access token to access github api",
472502
)
473503
parser.add_argument(

0 commit comments

Comments
 (0)