3 changes: 3 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/.gitignore
@@ -0,0 +1,3 @@
*.zip
deployment/
venv/
20 changes: 20 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/Makefile
@@ -0,0 +1,20 @@
all: run-local

clean:
	rm -rf deployment
	rm -rf venv
	rm -rf deployment.zip

venv/bin/python:
	virtualenv venv
	venv/bin/pip install -r requirements.txt

deployment.zip:
	mkdir -p deployment
	cp -r lambda_function.py lib ./deployment/.
	pip3.10 install -r requirements.txt -t ./deployment/. --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.10 --upgrade
	cd ./deployment && zip -q -r ../deployment.zip .

.PHONY: create-deployment-package
create-deployment-package: deployment.zip
@@ -0,0 +1,60 @@
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
import requests


@dataclass
class TimeRange:
start: str
end: str


@dataclass
class BenchmarkTimeSeriesItem:
group_info: Dict[str, Any]
num_of_dp: int
data: List[Dict[str, Any]] = field(default_factory=list)


@dataclass
class BenchmarkTimeSeriesApiData:
time_series: List[BenchmarkTimeSeriesItem]
time_range: TimeRange


@dataclass
class BenchmarkTimeSeriesApiResponse:
data: BenchmarkTimeSeriesApiData

@classmethod
def from_request(
cls, url: str, query: dict, timeout: int = 180
) -> "BenchmarkTimeSeriesApiResponse":
"""
Send a POST request and parse into BenchmarkTimeSeriesApiResponse.

Args:
url: API endpoint
query: JSON payload must
timeout: max seconds to wait for connect + response (default: 30)
Returns:
ApiResponse
Raises:
requests.exceptions.RequestException if network/timeout/HTTP error
RuntimeError if the API returns an "error" field or malformed data
"""
resp = requests.post(url, json=query, timeout=timeout)
resp.raise_for_status()
payload = resp.json()

if "error" in payload:
raise RuntimeError(f"API error: {payload['error']}")
try:
tr = TimeRange(**payload["data"]["time_range"])
ts = [
BenchmarkTimeSeriesItem(**item)
for item in payload["data"]["time_series"]
]
        except Exception as e:
            raise RuntimeError(f"Malformed API payload: {e}") from e
return cls(data=BenchmarkTimeSeriesApiData(time_series=ts, time_range=tr))
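
A brief, hedged usage sketch for this client; the URL, payload shape, and time values below are illustrative placeholders:

# Hedged usage sketch; URL and time range are placeholders.
query = {
    "name": "compiler_precompute",
    "query_params": {"startTime": "2024-01-01T00:00:00", "stopTime": "2024-01-08T00:00:00"},
}
resp = BenchmarkTimeSeriesApiResponse.from_request(
    "http://localhost:3000/api/benchmark/get_time_series", query, timeout=60
)
for item in resp.data.time_series:
    print(item.group_info, item.num_of_dp, len(item.data))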
79 changes: 79 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/common/config.py
@@ -0,0 +1,79 @@
from common.config_model import (
BenchmarkApiSource,
BenchmarkConfig,
BenchmarkRegressionConfigBook,
DayRangeWindow,
Frequency,
RegressionPolicy,
Policy,
RangeConfig,
)

# Compiler benchmark regression config
# todo(elainewy): eventually each team should configure their own benchmark regression config; currently placed here for the lambda
COMPILER_BENCHMARK_CONFIG = BenchmarkConfig(
name="Compiler Benchmark Regression",
id="compiler_regression",
source=BenchmarkApiSource(
api_query_url="http://localhost:3000/api/benchmark/get_time_series",
type="benchmark_time_series_api",
        # currently we only detect regressions for h100 with dtype bfloat16 and mode inference;
        # we can extend this to other devices, dtypes, and modes in the future
api_endpoint_params_template="""
{
"name": "compiler_precompute",
"query_params": {
"commits": [],
"compilers": [],
"arch": "h100",
"device": "cuda",
"dtype": "bfloat16",
"granularity": "hour",
"mode": "inference",
"startTime": "{{ startTime }}",
"stopTime": "{{ stopTime }}",
"suites": ["torchbench", "huggingface", "timm_models"],
"workflowId": 0,
"branches": ["main"]
}
}
""",
),
    # build the baseline from the past 7 days and compare it with the last 1 day
policy=Policy(
frequency=Frequency(value=1, unit="days"),
range=RangeConfig(
baseline=DayRangeWindow(value=7),
comparison=DayRangeWindow(value=1),
),
        metrics={
            "passrate": RegressionPolicy(
                name="passrate", condition="greater_equal", threshold=0.9, baseline_aggregation="max",
            ),
            "geomean": RegressionPolicy(
                name="geomean", condition="greater_equal", threshold=0.95, baseline_aggregation="max",
            ),
            "compression_ratio": RegressionPolicy(
                name="compression_ratio", condition="greater_equal", threshold=0.9, baseline_aggregation="max",
            ),
        },
        notification_config={
            "type": "github",
            "repo": "pytorch/test-infra",
            "issue_number": "7081",
        },
),
)

BENCHMARK_REGRESSION_CONFIG = BenchmarkRegressionConfigBook(
configs={
"compiler_regression": COMPILER_BENCHMARK_CONFIG,
}
)


def get_benchmark_regression_config(config_id: str) -> BenchmarkConfig:
try:
return BENCHMARK_REGRESSION_CONFIG[config_id]
except KeyError:
raise ValueError(f"Invalid config id: {config_id}")
224 changes: 224 additions & 0 deletions aws/lambda/benchmark_regression_summary_report/common/config_model.py
@@ -0,0 +1,224 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field, fields
from datetime import timedelta
from typing import Any, ClassVar, Dict, Literal, Optional, Type, TypeVar

import requests
from jinja2 import Environment, Template, meta


# -------- Frequency --------
@dataclass(frozen=True)
class Frequency:
"""
The frequency of how often the report should be generated.
The minimum frequency we support is 1 day.
Attributes:
value: Number of units (e.g., 7 for 7 days).
unit: Unit of time, either "days" or "weeks".

Methods:
to_timedelta: Convert frequency into a datetime.timedelta.
get_text: return the frequency in text format
"""
    value: int
    unit: Literal["days", "weeks"]

    def to_timedelta(self) -> timedelta:
"""Convert frequency N days or M weeks into a datetime.timedelta."""
if self.unit == "days":
return timedelta(days=self.value)
elif self.unit == "weeks":
return timedelta(weeks=self.value)
else:
raise ValueError(f"Unsupported unit: {self.unit}")

    def get_text(self) -> str:
        return f"{self.value} {self.unit}"


# -------- Source --------
_JINJA_ENV = Environment(autoescape=False)

@dataclass
class BenchmarkApiSource:
"""
Defines the source of the benchmark data we want to query
api_query_url: the url of the api to query
api_endpoint_params_template: the jinjia2 template of the api endpoint's query params
default_ctx: the default context to use when rendering the api_endpoint_params_template
"""
api_query_url: str
api_endpoint_params_template: str
type: Literal["benchmark_time_series_api", "other"] = "benchmark_time_series_api"
default_ctx: Dict[str, Any] = field(default_factory=dict)

def required_template_vars(self) -> set[str]:
ast = _JINJA_ENV.parse(self.api_endpoint_params_template)
return set(meta.find_undeclared_variables(ast))

def render(self, ctx: Dict[str, Any], strict: bool = True) -> dict:
"""Render with caller-supplied context (no special casing for start/end)."""
merged = {**self.default_ctx, **ctx}

if strict:
required = self.required_template_vars()
missing = required - merged.keys()
if missing:
raise ValueError(f"Missing required vars: {missing}")
rendered = Template(self.api_endpoint_params_template).render(**merged)
return json.loads(rendered)


# -------- Policy: range windows --------
@dataclass
class DayRangeWindow:
value: int
    # "raw" indicates the window is fetched from the raw source data
    source: Literal["raw"] = "raw"

@dataclass
class RangeConfig:
"""
Defines the range of baseline and comparison windows for a given policy.
- baseline: the baseline window that build the baseline value
- comparison: the comparison window that we fetch data from to compare against the baseline value
"""
baseline: DayRangeWindow
comparison: DayRangeWindow

    def total_timedelta(self) -> timedelta:
        return timedelta(days=self.baseline.value + self.comparison.value)

    def comparison_timedelta(self) -> timedelta:
        return timedelta(days=self.comparison.value)

    def baseline_timedelta(self) -> timedelta:
        return timedelta(days=self.baseline.value)
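
# Hedged worked example: with baseline=DayRangeWindow(value=7) and
# comparison=DayRangeWindow(value=1), total_timedelta() spans 8 days: a
# caller can fetch 8 days of data, build the baseline from the first 7,
# and compare the most recent 1 day against it.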

# -------- Policy: metrics --------
@dataclass
class RegressionPolicy:
"""
Defines the policy for a given metric.
- new value muset be {x} baseline value:
- "greater_than": higher is better; new value must be strictly greater to baseline
- "less_than": lower is better; new value must be strictly lower to baseline
- "equal_to": new value should be ~= baseline * threshold within rel_tol
- "greater_equal": higher is better; new value must be greater or equal to baseline
- "less_equal": lower is better; new value must be less or equal to baseline
"""
name: str
condition: Literal["greater_than", "less_than", "equal_to","greater_equal","less_equal"]
threshold: float
baseline_aggregation: Literal["avg", "max", "min", "p50", "p90", "p95","latest","earliest"] = "max"
rel_tol: float = 1e-3 # used only for "equal_to"

def is_violation(self, value: float, baseline: float) -> bool:
target = baseline * self.threshold

if self.condition == "greater_than":
# value must be strictly greater than target
return value <= target

if self.condition == "greater_equal":
# value must be greater or equal to target
return value < target

if self.condition == "less_than":
# value must be strictly less than target
return value >= target

if self.condition == "less_equal":
# value must be less or equal to target
return value > target

if self.condition == "equal_to":
# |value - target| should be within rel_tol * max(1, |target|)
denom = max(1.0, abs(target))
return abs(value - target) > self.rel_tol * denom

raise ValueError(f"Unknown condition: {self.condition}")


T = TypeVar("T", bound="BaseNotificationConfig")


class BaseNotificationConfig:
    # every subclass must override this
    type_tag: ClassVar[str]

    @classmethod
    def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
        # pick only known fields for this dataclass
        kwargs = {f.name: d.get(f.name) for f in fields(cls)}
        return cls(**kwargs)  # type: ignore

@classmethod
def matches(cls, d: Dict[str, Any]) -> bool:
return d.get("type") == cls.type_tag


@dataclass
class GitHubNotificationConfig(BaseNotificationConfig):
type: str = "github"
repo: str = ""
issue_number: str = ""
type_tag: ClassVar[str] = "github"

def create_github_comment(self, body: str, github_token: str) -> Dict[str, Any]:
"""
Create a new comment on a GitHub issue.
Args:
notification_config: dict with keys:
- type: must be "github"
- repo: "owner/repo"
- issue: issue number (string or int)
body: text of the comment
token: GitHub personal access token or GitHub Actions token

Returns:
The GitHub API response as a dict (JSON).
"""
url = f"https://api.github.com/repos/{self.repo}/issues/{self.issue_number}/comments"
headers = {
"Authorization": f"token {github_token}",
"Accept": "application/vnd.github+json",
"User-Agent": "bench-reporter/1.0",
}
        resp = requests.post(url, headers=headers, json={"body": body}, timeout=30)
resp.raise_for_status()
return resp.json()
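

# Minimal dispatcher for Policy.get_github_notification_config below (a
# sketch: it returns the first registered config class whose type_tag
# matches the dict's "type" field).
_NOTIFICATION_CONFIG_TYPES: list[Type[BaseNotificationConfig]] = [
    GitHubNotificationConfig,
]


def notification_from_dict(d: Dict[str, Any]) -> BaseNotificationConfig:
    for klass in _NOTIFICATION_CONFIG_TYPES:
        if klass.matches(d):
            return klass.from_dict(d)
    raise ValueError(f"Unknown notification config type: {d.get('type')}")
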

@dataclass
class Policy:
frequency: Frequency
range: RangeConfig
metrics: Dict[str, RegressionPolicy]
notification_config: Optional[Dict[str, Any]] = None

def get_github_notification_config(self) -> Optional[GitHubNotificationConfig]:
if not self.notification_config:
return None
return notification_from_dict(self.notification_config) # type: ignore


# -------- Top-level benchmark regression config --------
@dataclass
class BenchmarkConfig:
"""
Represents a single benchmark regression configuration.

- BenchmarkConfig defines the benchmark regression config for a given benchmark.
- source: defines the source of the benchmark data we want to query
- policy: defines the policy for the benchmark regressions
- name: the name of the benchmark
- id: the id of the benchmark, this must be unique for each benchmark, and cannot be changed once set
"""
name: str
id: str
source: BenchmarkApiSource
policy: Policy


@dataclass
class BenchmarkRegressionConfigBook:
configs: Dict[str, BenchmarkConfig] = field(default_factory=dict)

def __getitem__(self, key: str) -> BenchmarkConfig:
config = self.configs.get(key, None)
if not config:
raise KeyError(f"Config {key} not found")
return config
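
Taken together, a hedged end-to-end sketch of one report cycle. The module path for the time-series client and all time values are assumptions for illustration:

# Hedged end-to-end sketch; module name and times are assumptions.
from datetime import datetime

from common.config import get_benchmark_regression_config
from common.benchmark_time_series_api_model import (  # assumed module name
    BenchmarkTimeSeriesApiResponse,
)

config = get_benchmark_regression_config("compiler_regression")

# Fetch the baseline + comparison windows in one query.
end = datetime(2024, 1, 8)
start = end - config.policy.range.total_timedelta()
query = config.source.render({"startTime": start.isoformat(), "stopTime": end.isoformat()})
resp = BenchmarkTimeSeriesApiResponse.from_request(config.source.api_query_url, query)

# Evaluate one metric against an aggregated baseline (placeholder values).
policy = config.policy.metrics["passrate"]
print(policy.is_violation(value=0.85, baseline=1.0))  # True: 0.85 < 1.0 * 0.9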