Skip to content

Commit 9ae4838

Browse files
authored
Improve the time series api + add policy for regression (#7156)
For API: - Add model filters during query - Add format options: table and raw - Add an API hook method for the frontend - Add listCommits API. For regression lambda: - Add regression policy for compilation latency (if new value > 1.05 x baseline, consider it a regression) - Change the data format to match the API
1 parent b51743d commit 9ae4838

File tree

14 files changed

+380
-84
lines changed

14 files changed

+380
-84
lines changed

aws/lambda/benchmark_regression_summary_report/common/benchmark_time_series_api_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def from_request(
5858
tr = TimeRange(**payload["data"]["time_range"])
5959
ts = [
6060
BenchmarkTimeSeriesItem(**item)
61-
for item in payload["data"]["time_series"]
61+
for item in payload["data"]["data"]["time_series"]
6262
]
6363
except Exception as e:
6464
raise RuntimeError(f"Malformed API payload: {e}")

aws/lambda/benchmark_regression_summary_report/common/config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
api_endpoint_params_template="""
2828
{
2929
"name": "compiler_precompute",
30+
"response_formats":["time_series"],
3031
"query_params": {
3132
"commits": [],
3233
"compilers": [],
@@ -50,7 +51,7 @@
5051
policy=Policy(
5152
frequency=Frequency(value=1, unit="days"),
5253
range=RangeConfig(
53-
baseline=DayRangeWindow(value=7),
54+
baseline=DayRangeWindow(value=5),
5455
comparison=DayRangeWindow(value=2),
5556
),
5657
metrics={
@@ -72,6 +73,12 @@
7273
threshold=0.95,
7374
baseline_aggregation="max",
7475
),
76+
"compilation_latency": RegressionPolicy(
77+
name="compilation_latency",
78+
condition="less_equal",
79+
threshold=1.05,
80+
baseline_aggregation="min",
81+
),
7582
},
7683
notification_config={
7784
"type": "github",
Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,45 @@
11
SELECT DISTINCT
22
replaceOne(head_branch, 'refs/heads/', '') AS branch,
33
head_sha AS commit,
4-
workflow_id AS id,
5-
timestamp
4+
workflow_id,
5+
toDate(fromUnixTimestamp(timestamp), 'UTC') AS date
66
FROM benchmark.oss_ci_benchmark_torchinductor
77
PREWHERE
88
timestamp >= toUnixTimestamp({startTime: DateTime64(3)})
99
AND timestamp < toUnixTimestamp({stopTime: DateTime64(3)})
1010
WHERE
11+
-- optional branches
1112
(
1213
has(
1314
{branches: Array(String)},
1415
replaceOne(head_branch, 'refs/heads/', '')
1516
)
1617
OR empty({branches: Array(String)})
1718
)
19+
-- optional suites
1820
AND (
19-
has({suites: Array(String) }, suite)
20-
OR empty({suites: Array(String) })
21+
has({suites: Array(String)}, suite)
22+
OR empty({suites: Array(String)})
2123
)
22-
AND benchmark_dtype = {dtype: String}
23-
AND benchmark_mode = {mode: String}
24-
AND device = {device: String}
25-
AND multiSearchAnyCaseInsensitive(arch, {arch: Array(String)})
26-
ORDER BY timestamp
24+
-- optional dtype
25+
AND (
26+
benchmark_dtype = {dtype: String}
27+
OR empty({dtype: String})
28+
)
29+
-- optional mode
30+
AND (
31+
benchmark_mode = {mode: String}
32+
OR empty({mode: String})
33+
)
34+
-- optional device
35+
AND (
36+
device = {device: String}
37+
OR empty({device: String})
38+
)
39+
-- optional arch (array param); if empty array, skip filter
40+
AND (
41+
multiSearchAnyCaseInsensitive(arch, {arch: Array(String)})
42+
OR empty({arch: Array(String)})
43+
)
44+
ORDER BY branch, timestamp
2745
SETTINGS session_timezone = 'UTC';

torchci/clickhouse_queries/compilers_benchmark_api_query/params.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"dtype": "String",
99
"granularity": "String",
1010
"mode": "String",
11-
"suites": "Array(String)"
11+
"suites": "Array(String)",
12+
"models": "Array(String)"
1213
},
1314
"tests": []
1415
}

torchci/clickhouse_queries/compilers_benchmark_api_query/query.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ WHERE
3030
has({suites: Array(String) }, suite)
3131
OR empty({suites: Array(String) })
3232
)
33+
AND (
34+
has({models: Array(String)}, model_name)
35+
OR empty({models: Array(String) })
36+
)
3337
AND benchmark_dtype = {dtype: String}
3438
AND benchmark_mode = {mode: String}
3539
AND device = {device: String}

torchci/components/benchmark/compilers/SummaryPanel.tsx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,6 @@ export function SummaryPanel({
470470
if (l === r) {
471471
return "";
472472
}
473-
474473
// Decreasing more than x%
475474
if (r - l > RELATIVE_THRESHOLD * r) {
476475
return styles.ok;

torchci/lib/benchmark/api_helper/compilers/get_compiler_benchmark_data.ts

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ import {
77
} from "./helpers/common";
88
import { toGeneralCompilerData } from "./helpers/general";
99
import { toPrecomputeCompilerData } from "./helpers/precompute";
10-
import { CompilerQueryType } from "./type";
10+
import {
11+
CompilerQueryType,
12+
defaultGetTimeSeriesInputs,
13+
defaultListCommitsInputs,
14+
} from "./type";
1115
//["x86_64","NVIDIA A10G","NVIDIA H100 80GB HBM3"]
1216
const COMPILER_BENCHMARK_TABLE_NAME = "compilers_benchmark_api_query";
1317
const COMPILER_BENCHMARK_COMMITS_TABLE_NAME =
@@ -16,7 +20,7 @@ const COMPILER_BENCHMARK_COMMITS_TABLE_NAME =
1620
export async function getCompilerBenchmarkData(
1721
inputparams: any,
1822
type: CompilerQueryType = CompilerQueryType.PRECOMPUTE,
19-
format: string = "time_series"
23+
formats: string[] = ["time_series"]
2024
) {
2125
const rows = await getCompilerDataFromClickhouse(inputparams);
2226

@@ -26,55 +30,85 @@ export async function getCompilerBenchmarkData(
2630

2731
switch (type) {
2832
case CompilerQueryType.PRECOMPUTE:
29-
return toPrecomputeCompilerData(rows, format);
33+
return toPrecomputeCompilerData(rows, formats);
3034
case CompilerQueryType.GENERAL:
31-
return toGeneralCompilerData(rows, format);
35+
return toGeneralCompilerData(rows, formats);
3236
default:
3337
throw new Error(`Invalid compiler query type, got ${type}`);
3438
}
3539
}
3640

41+
export async function getCompilerCommits(inputparams: any): Promise<any[]> {
42+
if (!inputparams.startTime || !inputparams.stopTime) {
43+
throw new Error("no start/end time provided in request");
44+
}
45+
const queryParams = {
46+
...defaultListCommitsInputs, // base defaults
47+
...inputparams, // override with caller's values
48+
};
49+
50+
if (queryParams.arch && queryParams.device) {
51+
const arch_list = toQueryArch(inputparams.device, inputparams.arch);
52+
queryParams["arch"] = arch_list;
53+
}
54+
55+
const commit_results = await queryClickhouseSaved(
56+
COMPILER_BENCHMARK_COMMITS_TABLE_NAME,
57+
queryParams
58+
);
59+
return commit_results;
60+
}
61+
3762
async function getCompilerDataFromClickhouse(inputparams: any): Promise<any[]> {
3863
const start = Date.now();
39-
const arch_list = toQueryArch(inputparams.device, inputparams.arch);
40-
inputparams["arch"] = arch_list;
64+
65+
const queryParams = {
66+
...defaultGetTimeSeriesInputs, // base defaults
67+
...inputparams, // override with caller's values
68+
};
69+
70+
if (queryParams.arch && queryParams.device) {
71+
const arch_list = toQueryArch(queryParams.device, queryParams.arch);
72+
queryParams["arch"] = arch_list;
73+
}
4174

4275
// use the startTime and endTime to fetch commits from clickhouse if commits field is not provided
43-
if (!inputparams.commits || inputparams.commits.length == 0) {
44-
if (!inputparams.startTime || !inputparams.stopTime) {
76+
if (!queryParams.commits || queryParams.commits.length == 0) {
77+
if (!queryParams.startTime || !queryParams.stopTime) {
4578
console.log("no commits or start/end time provided in request");
4679
return [];
4780
}
81+
4882
// get commits from clickhouse
4983
const commit_results = await queryClickhouseSaved(
5084
COMPILER_BENCHMARK_COMMITS_TABLE_NAME,
51-
inputparams
85+
queryParams
5286
);
5387
// get unique commits
5488
const unique_commits = [...new Set(commit_results.map((c) => c.commit))];
5589
if (unique_commits.length === 0) {
56-
console.log("no commits found in clickhouse using", inputparams);
90+
console.log("no commits found in clickhouse using", queryParams);
5791
return [];
5892
}
5993

6094
console.log(
61-
"no commits provided in request, found unqiue commits",
62-
unique_commits
95+
`no commits provided in request, searched unqiue commits based on
96+
start/end time unique_commits: ${unique_commits.length}`
6397
);
6498

6599
if (commit_results.length > 0) {
66-
inputparams["commits"] = unique_commits;
100+
queryParams["commits"] = unique_commits;
67101
} else {
68-
console.log(`no commits found in clickhouse using ${inputparams}`);
102+
console.log(`no commits found in clickhouse using ${queryParams}`);
69103
return [];
70104
}
71105
} else {
72-
console.log("commits provided in request", inputparams.commits);
106+
console.log("commits provided in request", queryParams.commits);
73107
}
74108

75109
let rows = await queryClickhouseSaved(
76110
COMPILER_BENCHMARK_TABLE_NAME,
77-
inputparams
111+
queryParams
78112
);
79113
const end = Date.now();
80114
console.log("time to get compiler timeseris data", end - start);
@@ -83,6 +117,8 @@ async function getCompilerDataFromClickhouse(inputparams: any): Promise<any[]> {
83117
return [];
84118
}
85119

120+
console.log("rows from clickhouse", rows[0]);
121+
86122
// extract backend from output in runtime instead of doing it in the query. since it's expensive for regex matching.
87123
// TODO(elainewy): we should add this as a column in the database for less runtime logics.
88124
rows.map((row) => {

torchci/lib/benchmark/api_helper/compilers/helpers/common.ts

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
import { groupByBenchmarkData } from "../../utils";
2-
3-
export function to_table_compiler_data(data: any[]) {
4-
const res = groupByBenchmarkData(
5-
data,
6-
["dtype", "arch", "device", "mode", "workflow_id", "granularity_bucket"],
7-
["metric", "compiler"]
8-
);
9-
return res;
10-
}
11-
121
export function extractBackendSqlStyle(
132
output: string,
143
suite: string,

torchci/lib/benchmark/api_helper/compilers/helpers/general.ts

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const COMPILER_GENERAL_TS_GROUP_KEY = [
1313
"metric",
1414
"mode",
1515
"model",
16+
"branch",
1617
];
1718
const COMPILER_GENERAL_TS_SUB_GROUP_KEY = ["workflow_id"];
1819

@@ -22,6 +23,7 @@ const COMPILER_GENERAL_TABLE_GROUP_KEY = [
2223
"device",
2324
"mode",
2425
"workflow_id",
26+
"branch",
2527
"compiler",
2628
"model",
2729
];
@@ -35,31 +37,40 @@ const COMPILER_GENERAL_TABLE_SUB_GROUP_KEY = ["metric"];
3537
*/
3638
export function toGeneralCompilerData(
3739
rawData: any[],
38-
type: string = "time_series"
40+
formats: string[] = ["time_series"]
3941
) {
4042
const start_ts = new Date(rawData[0].granularity_bucket).getTime();
4143
const end_ts = new Date(
4244
rawData[rawData.length - 1].granularity_bucket
4345
).getTime();
4446

45-
let res: any[] = [];
46-
switch (type) {
47+
let formats_result: any = {};
48+
49+
formats.forEach((format) => {
50+
const data = getformat(rawData, format);
51+
formats_result[format] = data;
52+
});
53+
return toTimeSeriesResponse(formats_result, rawData.length, start_ts, end_ts);
54+
}
55+
56+
function getformat(data: any, format: string) {
57+
switch (format) {
4758
case "time_series":
48-
res = to_time_series_data(
49-
rawData,
59+
return to_time_series_data(
60+
data,
5061
COMPILER_GENERAL_TS_GROUP_KEY,
5162
COMPILER_GENERAL_TS_SUB_GROUP_KEY
5263
);
53-
break;
5464
case "table":
55-
res = groupByBenchmarkData(
56-
rawData,
65+
return groupByBenchmarkData(
66+
data,
5767
COMPILER_GENERAL_TABLE_GROUP_KEY,
5868
COMPILER_GENERAL_TABLE_SUB_GROUP_KEY
5969
);
6070
break;
71+
case "raw":
72+
return data;
6173
default:
6274
throw new Error("Invalid type");
6375
}
64-
return toTimeSeriesResponse(res, rawData.length, start_ts, end_ts);
6576
}

0 commit comments

Comments
 (0)