
Commit be13115

add db writes
ghstack-source-id: 5d0d2b5 Pull-Request: #7095
1 parent 947296c

5 files changed: +314 -14 lines changed

aws/lambda/benchmark_regression_summary_report/common/benchmark_time_series_api_model.py

Lines changed: 23 additions & 1 deletion
@@ -1,5 +1,6 @@
+import datetime as dt
 from dataclasses import dataclass, field
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import requests
 
@@ -62,3 +63,24 @@ def from_request(
         except Exception as e:
             raise RuntimeError(f"Malformed API payload: {e}")
         return cls(data=BenchmarkTimeSeriesApiData(time_series=ts, time_range=tr))
+
+
+def get_latest_meta_info(
+    time_series: List[BenchmarkTimeSeriesItem],
+) -> Optional[dict[str, Any]]:
+    if not time_series:
+        return None
+
+    pts = [p for s in time_series for p in s.data]
+    latest = max(
+        pts,
+        key=lambda p: dt.datetime.fromisoformat(
+            p["granularity_bucket"].replace("Z", "+00:00")
+        ),
+    )
+    return {
+        "commit": latest.get("commit", ""),
+        "branch": latest.get("branch", ""),
+        "timestamp": latest.get("granularity_bucket", ""),
+        "workflow_id": latest.get("workflow_id", ""),
+    }
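The new get_latest_meta_info helper flattens every series into one list of points and returns the metadata of the point with the newest granularity_bucket timestamp. Below is a minimal sketch of calling it, assuming the module path mirrors the file layout; the stand-in series object and all sample values are invented for illustration and only carry the data attribute the helper actually reads.

import dataclasses
from typing import Any, Dict, List

from common.benchmark_time_series_api_model import get_latest_meta_info  # assumed module path


# Stand-in with just the `data` field the helper touches; the real
# BenchmarkTimeSeriesItem is defined in this same module.
@dataclasses.dataclass
class _FakeSeries:
    data: List[Dict[str, Any]]


series = [
    _FakeSeries(
        data=[
            {"commit": "abc1234", "branch": "main",
             "granularity_bucket": "2024-05-01T00:00:00Z", "workflow_id": "111"},
            {"commit": "def5678", "branch": "main",
             "granularity_bucket": "2024-05-02T00:00:00Z", "workflow_id": "222"},
        ]
    )
]

# The newest bucket wins, so this prints the second point's metadata:
# {'commit': 'def5678', 'branch': 'main',
#  'timestamp': '2024-05-02T00:00:00Z', 'workflow_id': '222'}
print(get_latest_meta_info(series))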

aws/lambda/benchmark_regression_summary_report/common/config_model.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def to_timedelta(self) -> timedelta:
         raise ValueError(f"Unsupported unit: {self.unit}")
 
     def get_text(self):
-        return f"{self.value} {self.unit}"
+        return f"{self.value}_{self.unit}"
 
 
 # -------- Source --------
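The one-line change above switches get_text() from a space-separated to an underscore-separated token, which is what the new ReportManager later embeds in its report payload. A tiny illustration, assuming Frequency exposes value and unit fields as the f-string implies (the constructor call is hypothetical):

freq = Frequency(value=7, unit="days")  # hypothetical construction
freq.get_text()  # previously "7 days"; now returns "7_days"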
Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
import dataclasses
import datetime as dt
import json
import logging
import uuid
from typing import Any, Dict, List

import clickhouse_connect
from common.config_model import BenchmarkConfig, Frequency
from common.regression_utils import PerGroupResult
from jinja2 import Template


logger = logging.getLogger()


REPORT_MD_TEMPLATE = """# Benchmark Report {{id}}
config_id: `{{ report_id }}`

> **Status:** {{ status }} · **Frequency:** {{ frequency }}

## Latest
- **Timestamp:** `{{ latest.timestamp | default('') }}`
- **Commit:** `{{ (latest.commit | default(''))[:12] }}`
- **Branch:** `{{ latest.branch | default('') }}`
- **Workflow ID:** `{{ latest.workflow_id | default('') }}`

## Summary
| Metric | Value |
| :-- | --: |
| Total | {{ summary.total_count | default(0) }} |
| Regressions | {{ summary.regression_count | default(0) }} |
| Suspicious | {{ summary.suspicious_count | default(0) }} |
| No Regression | {{ summary.no_regression_count | default(0) }} |
| Insufficient Data | {{ summary.insufficient_data_count | default(0) }} |
"""


class ReportManager:
    """
    Handles DB insertion and notification processing.
    """

    def __init__(
        self,
        db_table_name: str,
        config_id: str,
        config: BenchmarkConfig,
        regression_summary: Dict[str, Any],
        latest_meta_info: Dict[str, Any],
        result: List[PerGroupResult],
        type: str = "general",
        repo: str = "pytorch/pytorch",
    ):
        self.regression_summary = regression_summary
        self.regression_result = result
        self.config_id = config_id
        self.config = config
        self.status = self._resolve_status(regression_summary)
        self.latest_meta_info = self._validate_latest_meta_info(latest_meta_info)
        self.report_data = self._to_report_data(
            config_id=config_id,
            summary=self.regression_summary,
            report=self.regression_result,
            latest=self.latest_meta_info,
            status=self.status,
            frequency=self.config.policy.frequency,
        )
        self.type = type
        self.repo = repo
        self.db_table_name = db_table_name
        self.id = str(uuid.uuid4())

    def run(self, cc: clickhouse_connect.driver.client.Client) -> None:
        try:
            self.insert_to_db(cc)
        except Exception as e:
            logger.error(f"failed to insert report to db, error: {e}")
            raise

    def _to_markdown(self):
        md = Template(REPORT_MD_TEMPLATE, trim_blocks=True, lstrip_blocks=True).render(
            id=self.id,
            status=self.status,
            report_id=self.config_id,
            summary=self.regression_summary,
            latest=self.latest_meta_info,
            frequency=self.config.policy.frequency.get_text(),
        )
        return md

    def insert_to_db(
        self,
        cc: clickhouse_connect.driver.client.Client,
    ) -> None:
        logger.info(
            "[%s] preparing data for db insertion of report (%s)...", self.config_id, self.id
        )

        table = self.db_table_name

        latest_ts_str = self.latest_meta_info.get("timestamp")
        if not latest_ts_str:
            raise ValueError(
                f"timestamp from latest is required, latest is {self.latest_meta_info}"
            )

        # ---- Convert to UTC and format as ClickHouse-friendly 'YYYY-MM-DD HH:MM:SS' ----
        aware = dt.datetime.fromisoformat(latest_ts_str.replace("Z", "+00:00"))
        utc_naive = aware.astimezone(dt.timezone.utc).replace(tzinfo=None)
        last_record_ts = utc_naive.strftime(
            "%Y-%m-%d %H:%M:%S"
        )  # used for {DateTime64(0)}

        report_json = json.dumps(
            self.report_data, ensure_ascii=False, separators=(",", ":"), default=str
        )

        params = {
            "id": str(self.id),  # column is UUID, bound via {id:UUID}
            "report_id": self.config_id,
            "type": self.type,
            "status": self.status,
            "last_record_commit": self.latest_meta_info.get("commit", ""),
            "last_record_ts": last_record_ts,  # already UTC, timezone-naive
            "regression_count": int(self.regression_summary.get("regression_count", 0)),
            "insufficient_data_count": int(
                self.regression_summary.get("insufficient_data_count", 0)
            ),
            "suspected_regression_count": int(
                self.regression_summary.get("suspicious_count", 0)
            ),
            "total_count": int(self.regression_summary.get("total_count", 0)),
            "repo": self.repo,
            "report_json": report_json,
        }

        logger.info(
            "[%s] inserting benchmark regression report (%s)", self.config_id, self.id
        )

        # Plain INSERT ... SELECT ... FROM system.one, guarded by NOT EXISTS
        cc.query(
            f"""
            INSERT INTO {table} (
                id,
                report_id,
                last_record_ts,
                last_record_commit,
                `type`,
                status,
                regression_count,
                insufficient_data_count,
                suspected_regression_count,
                total_count,
                repo,
                report
            )
            SELECT
                {{id:UUID}},
                {{report_id:String}},
                {{last_record_ts:DateTime64(0)}},
                {{last_record_commit:String}},
                {{type:String}},
                {{status:String}},
                {{regression_count:UInt32}},
                {{insufficient_data_count:UInt32}},
                {{suspected_regression_count:UInt32}},
                {{total_count:UInt32}},
                {{repo:String}},
                {{report_json:String}}
            FROM system.one
            WHERE NOT EXISTS (
                SELECT 1
                FROM {table}
                WHERE report_id = {{report_id:String}}
                  AND `type` = {{type:String}}
                  AND repo = {{repo:String}}
                  AND stamp = toDate({{last_record_ts:DateTime64(0)}})
            );
            """,
            parameters=params,
        )

        logger.info(
            "[%s] Done. inserted benchmark regression report (%s)",
            self.config_id,
            self.id,
        )

    def _resolve_status(self, regression_summary: Dict[str, Any]) -> str:
        status = (
            "regression"
            if regression_summary.get("regression_count", 0) > 0
            else "suspicious"
            if regression_summary.get("suspicious_count", 0) > 0
            else "no_regression"
        )
        return status

    def _validate_latest_meta_info(
        self, latest_meta_info: Dict[str, Any]
    ) -> Dict[str, Any]:
        latest_commit = latest_meta_info.get("commit")
        if not latest_commit:
            raise ValueError(
                f"commit from latest is required, latest is {latest_meta_info}"
            )
        latest_ts_str = latest_meta_info.get("timestamp")
        if not latest_ts_str:
            raise ValueError(
                f"timestamp from latest is required, latest is {latest_meta_info}"
            )
        return latest_meta_info

    def _to_report_data(
        self,
        config_id: str,
        summary: Dict[str, Any],
        report: List[Any],  # List[PerGroupResult] or dicts
        latest: dict[str, Any],  # {"commit","branch","timestamp","workflow_id"}
        status: str,
        frequency: Frequency,
    ) -> dict[str, Any]:
        latest_commit = latest.get("commit")
        if not latest_commit:
            raise ValueError(
                f"commit from latest is required, latest is {latest}"
            )
        latest_ts_str = latest.get("timestamp")
        if not latest_ts_str:
            raise ValueError(f"timestamp from latest is required, latest is {latest}")

        def to_dict(x):  # handle dataclass or dict/object
            if dataclasses.is_dataclass(x):
                return dataclasses.asdict(x)
            if isinstance(x, dict):
                return x
            return vars(x) if hasattr(x, "__dict__") else {"value": str(x)}

        return {
            "status": status,
            "report_id": config_id,
            "summary": summary,
            "latest": latest,
            "details": [to_dict(x) for x in report],
            "frequency": frequency.get_text(),
        }
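To show how the pieces fit together, here is a rough sketch of how a caller (for example the lambda handler) might drive ReportManager. Only the ReportManager signature comes from this diff and clickhouse_connect.get_client is the library's standard client factory; the host, credentials, table name, config id, and the config, summary, latest, and results variables are placeholders for values produced by the earlier config-loading and regression-detection steps.

import clickhouse_connect

# Placeholder connection details; in the lambda these would come from the environment.
cc = clickhouse_connect.get_client(
    host="clickhouse.example.com",
    username="writer",
    password="***",
)

manager = ReportManager(
    db_table_name="benchmark.benchmark_regression_report",  # placeholder table name
    config_id="compiler_regression_nightly",                # placeholder config id
    config=config,                # BenchmarkConfig loaded earlier
    regression_summary=summary,   # e.g. {"total_count": 42, "regression_count": 1, ...}
    latest_meta_info=latest,      # output of get_latest_meta_info(...)
    result=results,               # List[PerGroupResult] from regression_utils
)
# Inserts a single row; the NOT EXISTS guard skips the write when a report for
# the same (report_id, type, repo, stamp) already exists.
manager.run(cc)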
