3 changes: 3 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/.gitignore
@@ -0,0 +1,3 @@
*.zip
deployment/
venv/
@@ -0,0 +1,64 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List

import requests


# Data classes modeling the response payload of the get_time_series API


@dataclass
class TimeRange:
start: str
end: str


@dataclass
class BenchmarkTimeSeriesItem:
group_info: Dict[str, Any]
num_of_dp: int
data: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class BenchmarkTimeSeriesApiData:
time_series: List[BenchmarkTimeSeriesItem]
time_range: TimeRange


@dataclass
class BenchmarkTimeSeriesApiResponse:
data: BenchmarkTimeSeriesApiData

@classmethod
def from_request(
cls, url: str, query: dict, timeout: int = 180
) -> "BenchmarkTimeSeriesApiResponse":
"""
Send a POST request and parse into BenchmarkTimeSeriesApiResponse.

Args:
url: API endpoint
            query: JSON payload to send with the POST request
            timeout: max seconds to wait for connect + response (default: 180)
        Returns:
            BenchmarkTimeSeriesApiResponse
Raises:
requests.exceptions.RequestException if network/timeout/HTTP error
RuntimeError if the API returns an "error" field or malformed data
"""
resp = requests.post(url, json=query, timeout=timeout)
resp.raise_for_status()
payload = resp.json()

if "error" in payload:
raise RuntimeError(f"API error: {payload['error']}")
try:
tr = TimeRange(**payload["data"]["time_range"])
ts = [
BenchmarkTimeSeriesItem(**item)
for item in payload["data"]["time_series"]
]
except Exception as e:
raise RuntimeError(f"Malformed API payload: {e}")
return cls(data=BenchmarkTimeSeriesApiData(time_series=ts, time_range=tr))
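A minimal usage sketch for the response model above. The import path is an assumption, since the filename header for this file was lost in the diff view; the query fields and timestamps are illustrative, mirroring the compiler config in common/config.py below:

from common.benchmark_time_series_api_model import BenchmarkTimeSeriesApiResponse

# Illustrative payload; field names mirror the compiler config below.
query = {
    "name": "compiler_precompute",
    "query_params": {
        "arch": "h100",
        "device": "cuda",
        "dtype": "bfloat16",
        "mode": "inference",
        "granularity": "hour",
        "startTime": "2025-01-01T00:00:00",
        "stopTime": "2025-01-08T00:00:00",
        "branches": ["main"],
    },
}

resp = BenchmarkTimeSeriesApiResponse.from_request(
    "https://hud.pytorch.org/api/benchmark/get_time_series", query
)
print(resp.data.time_range.start, "->", resp.data.time_range.end)
print(f"{len(resp.data.time_series)} time series groups")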
94 changes: 94 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/common/config.py
@@ -0,0 +1,94 @@
from common.config_model import (
BenchmarkApiSource,
BenchmarkConfig,
BenchmarkRegressionConfigBook,
DayRangeWindow,
Frequency,
Policy,
RangeConfig,
RegressionPolicy,
)


# Compiler benchmark regression config
# todo(elainewy): eventually each team should configure
# their own benchmark regression config; currently placed
# here for the lambda


COMPILER_BENCHMARK_CONFIG = BenchmarkConfig(
name="Compiler Benchmark Regression",
id="compiler_regression",
source=BenchmarkApiSource(
api_query_url="https://hud.pytorch.org/api/benchmark/get_time_series",
type="benchmark_time_series_api",
        # Currently we only detect regressions for h100 with dtype bfloat16 and mode inference.
        # We can extend this to other devices, dtypes, and modes in the future.
api_endpoint_params_template="""
{
"name": "compiler_precompute",
"query_params": {
"commits": [],
"compilers": [],
"arch": "h100",
"device": "cuda",
"dtype": "bfloat16",
"granularity": "hour",
"mode": "inference",
"startTime": "{{ startTime }}",
"stopTime": "{{ stopTime }}",
"suites": ["torchbench", "huggingface", "timm_models"],
"workflowId": 0,
"branches": ["main"]
}
}
""",
),
    # set baseline from the past 7 days using max, and compare against the last 2 days
policy=Policy(
frequency=Frequency(value=1, unit="days"),
range=RangeConfig(
baseline=DayRangeWindow(value=7),
comparison=DayRangeWindow(value=2),
),
metrics={
"passrate": RegressionPolicy(
name="passrate",
condition="greater_equal",
threshold=0.9,
baseline_aggregation="max",
),
"geomean": RegressionPolicy(
name="geomean",
condition="greater_equal",
threshold=0.95,
baseline_aggregation="max",
),
"compression_ratio": RegressionPolicy(
name="compression_ratio",
condition="greater_equal",
threshold=0.9,
baseline_aggregation="max",
),
},
notification_config={
"type": "github",
"repo": "pytorch/test-infra",
"issue": "7081",
},
),
)

BENCHMARK_REGRESSION_CONFIG = BenchmarkRegressionConfigBook(
configs={
"compiler_regression": COMPILER_BENCHMARK_CONFIG,
}
)


def get_benchmark_regression_config(config_id: str) -> BenchmarkConfig:
"""Get benchmark regression config by config id"""
try:
return BENCHMARK_REGRESSION_CONFIG[config_id]
    except KeyError as e:
        raise ValueError(f"Invalid config id: {config_id}") from e
194 changes: 194 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/common/config_model.py
@@ -0,0 +1,194 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Dict, Literal, Optional

from jinja2 import Environment, meta, Template


# -------- Frequency --------
@dataclass(frozen=True)
class Frequency:
"""
    How often the report should be generated.
The minimum frequency we support is 1 day.
Attributes:
value: Number of units (e.g., 7 for 7 days).
unit: Unit of time, either "days" or "weeks".

Methods:
to_timedelta: Convert frequency into a datetime.timedelta.
get_text: return the frequency in text format
"""

value: int
unit: Literal["days", "weeks"]

def to_timedelta(self) -> timedelta:
"""Convert frequency N days or M weeks into a datetime.timedelta."""
if self.unit == "days":
return timedelta(days=self.value)
elif self.unit == "weeks":
return timedelta(weeks=self.value)
else:
raise ValueError(f"Unsupported unit: {self.unit}")

    def get_text(self) -> str:
return f"{self.value} {self.unit}"


# -------- Source --------
_JINJA_ENV = Environment(autoescape=False)


@dataclass
class BenchmarkApiSource:
"""
    Defines the source of the benchmark data we want to query.
    api_query_url: the URL of the API to query
    api_endpoint_params_template: the Jinja2 template of the API endpoint's query params
    default_ctx: the default context to use when rendering api_endpoint_params_template
"""

api_query_url: str
api_endpoint_params_template: str
type: Literal["benchmark_time_series_api", "other"] = "benchmark_time_series_api"
default_ctx: Dict[str, Any] = field(default_factory=dict)

def required_template_vars(self) -> set[str]:
ast = _JINJA_ENV.parse(self.api_endpoint_params_template)
return set(meta.find_undeclared_variables(ast))

def render(self, ctx: Dict[str, Any], strict: bool = True) -> dict:
"""Render with caller-supplied context (no special casing for start/end)."""
merged = {**self.default_ctx, **ctx}

if strict:
required = self.required_template_vars()
missing = required - merged.keys()
if missing:
raise ValueError(f"Missing required vars: {missing}")
rendered = Template(self.api_endpoint_params_template).render(**merged)
return json.loads(rendered)


# -------- Policy: range windows --------
@dataclass
class DayRangeWindow:
value: int
    # "raw" indicates the window is fetched from the raw source data
source: Literal["raw"] = "raw"


@dataclass
class RangeConfig:
"""
Defines the range of baseline and comparison windows for a given policy.
    - baseline: the window used to build the baseline value
    - comparison: the window we fetch data from to compare against the baseline value
"""

baseline: DayRangeWindow
comparison: DayRangeWindow

def total_timedelta(self) -> timedelta:
return timedelta(days=self.baseline.value + self.comparison.value)

def comparison_timedelta(self) -> timedelta:
return timedelta(days=self.comparison.value)

def baseline_timedelta(self) -> timedelta:
return timedelta(days=self.baseline.value)


# -------- Policy: metrics --------
@dataclass
class RegressionPolicy:
"""
Defines the policy for a given metric.
    The new value must satisfy the condition relative to baseline * threshold:
    - "greater_than": higher is better; new value must be strictly greater than baseline * threshold
    - "less_than": lower is better; new value must be strictly less than baseline * threshold
    - "equal_to": new value must be ~= baseline * threshold within rel_tol
    - "greater_equal": higher is better; new value must be greater than or equal to baseline * threshold
    - "less_equal": lower is better; new value must be less than or equal to baseline * threshold
"""

name: str
condition: Literal[
"greater_than", "less_than", "equal_to", "greater_equal", "less_equal"
]
threshold: float
baseline_aggregation: Literal[
"avg", "max", "min", "p50", "p90", "p95", "latest", "earliest"
] = "max"
rel_tol: float = 1e-3 # used only for "equal_to"

def is_violation(self, value: float, baseline: float) -> bool:
target = baseline * self.threshold

if self.condition == "greater_than":
# value must be strictly greater than target
return value <= target

if self.condition == "greater_equal":
# value must be greater or equal to target
return value < target

if self.condition == "less_than":
# value must be strictly less than target
return value >= target

if self.condition == "less_equal":
# value must be less or equal to target
return value > target

if self.condition == "equal_to":
# |value - target| should be within rel_tol * max(1, |target|)
denom = max(1.0, abs(target))
return abs(value - target) > self.rel_tol * denom

raise ValueError(f"Unknown condition: {self.condition}")


@dataclass
class Policy:
frequency: Frequency
range: RangeConfig
metrics: Dict[str, RegressionPolicy]

# TODO(elainewy): add notification config
notification_config: Optional[Dict[str, Any]] = None


# -------- Top-level benchmark regression config --------
@dataclass
class BenchmarkConfig:
"""
Represents a single benchmark regression configuration.
- BenchmarkConfig defines the benchmark regression config for a given benchmark.
- source: defines the source of the benchmark data we want to query
- policy: defines the policy for the benchmark regressions, including frequency to
generate the report, range of the baseline and new values, and regression thresholds
for metrics
- name: the name of the benchmark
    - id: the id of the benchmark; it must be unique for each benchmark and cannot be changed once set
"""

name: str
id: str
source: BenchmarkApiSource
policy: Policy


@dataclass
class BenchmarkRegressionConfigBook:
configs: Dict[str, BenchmarkConfig] = field(default_factory=dict)

def __getitem__(self, key: str) -> BenchmarkConfig:
config = self.configs.get(key)
if not config:
raise KeyError(f"Config {key} not found")
return config
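To make the threshold semantics concrete, a small self-contained example of RegressionPolicy (values are made up): with condition "greater_equal" and threshold 0.9, a new value counts as a regression once it drops below 90% of the baseline.

from common.config_model import RegressionPolicy

policy = RegressionPolicy(
    name="passrate", condition="greater_equal", threshold=0.9
)

# Target = baseline * threshold = 0.95 * 0.9 = 0.855.
assert policy.is_violation(value=0.80, baseline=0.95)      # 0.80 < 0.855 -> regression
assert not policy.is_violation(value=0.90, baseline=0.95)  # 0.90 >= 0.855 -> ok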