20 changes: 10 additions & 10 deletions BackendBench/backends/llm.py
@@ -16,6 +16,7 @@

import torch

from BackendBench.errors import AgentError
from BackendBench.eval import (
CorrectnessTestResult,
eval_performance,
@@ -306,6 +307,11 @@ def test_kernel_correctness(
feedback_info.kernel_code = kernel_code

try:
# Only raise AgentError if kernel_code is missing or malformed
if not kernel_code or not isinstance(kernel_code, str):
raise AgentError(
"Kernel code is empty or not a string (agent failed to produce a kernel)."
)
kernel_file = self._generate_kernel_file_path(op_name, attempt)
if not os.path.exists(kernel_file):
save_kernel_to_file(kernel_code, kernel_file)
@@ -314,16 +320,12 @@ def test_kernel_correctness(
f"{op_name}_implementation_v{attempt}", kernel_file
)
module = importlib.util.module_from_spec(spec)

# Add to sys.modules so triton can find it
sys.modules[f"{op_name}_implementation_v{attempt}"] = module

try:
spec.loader.exec_module(module)

expected_name = f"{op_name}_kernel_impl"
if hasattr(module, expected_name):
# check if the kernel compile / is loadable
_ = getattr(module, expected_name)
else:
available_functions = [
@@ -334,12 +336,9 @@ def test_kernel_correctness(
raise ValueError(
f"Expected function '{expected_name}' not found. Available: {available_functions}"
)

finally:
if f"test_kernel_{op_name}_{attempt}" in sys.modules:
del sys.modules[f"test_kernel_{op_name}_{attempt}"]

# Clear CUDA cache and synchronize to prevent memory buildup
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
@@ -353,10 +352,7 @@ def test_kernel_correctness(
test_cases,
[],
)

# Start evaluation
evaluator.start_evaluation()
# Get results
results = evaluator.get_results()

for result in results:
@@ -370,6 +366,10 @@ def test_kernel_correctness(

return is_correct, feedback_info

except AgentError as e:
feedback_info["agent_error"] = str(e)
feedback_info["summary"] = f"Agent error: {str(e)}"
return False, feedback_info
except Exception as e:
logger.error(" ✗ Compilation failed:")
logger.error(f" Error: {str(e)}")
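The loading path in `test_kernel_correctness` above imports the generated kernel file with `importlib` and looks up the expected `{op_name}_kernel_impl` entry point. A condensed sketch of that pattern, with the cleanup and CUDA housekeeping omitted (the function name and error text here are illustrative, not part of the PR):

```python
import importlib.util
import sys


def load_kernel_impl(op_name: str, kernel_file: str, attempt: int = 1):
    """Illustrative loader: import a generated kernel file and return its entry point."""
    module_name = f"{op_name}_implementation_v{attempt}"
    spec = importlib.util.spec_from_file_location(module_name, kernel_file)
    module = importlib.util.module_from_spec(spec)
    # Register the module before exec_module so Triton can resolve it by name.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    expected_name = f"{op_name}_kernel_impl"
    if not hasattr(module, expected_name):
        available = [name for name in dir(module) if not name.startswith("_")]
        raise ValueError(f"Expected function '{expected_name}' not found. Available: {available}")
    return getattr(module, expected_name)
```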
16 changes: 16 additions & 0 deletions BackendBench/errors.py
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


class AgentError(Exception):
"""
Exception raised for errors related to LLM/agent failures,
such as rate limits, empty code, bad formatting, or API issues.
"""

def __init__(self, message: str):
super().__init__(message)
self.message = message
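A minimal usage sketch of the new exception, mirroring how the backend guards and records agent failures (the helper name below is illustrative):

```python
from BackendBench.errors import AgentError


def require_kernel_code(kernel_code) -> str:
    # Illustrative guard in the spirit of test_kernel_correctness: reject missing or
    # malformed agent output before trying to compile it.
    if not kernel_code or not isinstance(kernel_code, str):
        raise AgentError("Kernel code is empty or not a string (agent failed to produce a kernel).")
    return kernel_code


try:
    require_kernel_code(None)
except AgentError as e:
    # str(e) and e.message carry the same text; the backend records it as feedback
    # for the next generation attempt rather than treating it as a crash.
    print(f"Agent error: {e.message}")
```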
88 changes: 50 additions & 38 deletions BackendBench/llm_client.py
@@ -9,9 +9,12 @@

import anthropic
import requests
from requests.exceptions import ConnectionError
from tenacity import retry
from tenacity.wait import wait_random_exponential

from BackendBench.errors import AgentError

from .kernel_templates import KernelTemplateManager


@@ -60,15 +63,26 @@ def readme_setup_section(self) -> str:

@retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2))
def call_llm(self, prompt: str) -> str:
response = self.client.messages.create(
model=self.model,
max_tokens=8000,
temperature=0.2,
timeout=120.0,
messages=[{"role": "user", "content": prompt}],
)
content = response.content[0].text
return content
try:
response = self.client.messages.create(
model=self.model,
max_tokens=8000,
temperature=0.2,
timeout=120.0,
messages=[{"role": "user", "content": prompt}],
)
content = response.content[0].text
if not content:
raise ConnectionError(
"API error: Empty response from LLM API (possible rate limit or outage)."
)
if "rate limit" in content.lower():
raise ConnectionError("API error: Rate limit encountered from LLM API.")
return content
except anthropic.AnthropicError as e:
raise ConnectionError(f"API error: Anthropic API error: {e}")
except ConnectionError:
# Re-raise API-level failures unchanged so the retry decorator can back off on them
raise
except Exception as e:
raise ConnectionError(f"API error: Unexpected error: {e}") from e

def generate_kernel(
self,
@@ -93,9 +107,7 @@ def generate_kernel(

try:
content = self.call_llm(prompt)
if not content:
raise RuntimeError("Empty response from LLM relay server")

# Only raise AgentError if kernel extraction fails
extracted_code = self._extract_code_from_response(content)

print("\n=== DEBUG: RAW LLM RELAY RESPONSE ===")
@@ -106,25 +118,18 @@ def generate_kernel(

return extracted_code

except requests.exceptions.RequestException as e:
raise RuntimeError(
f"Failed to communicate with LLM relay server for {op_name}: {str(e)}"
)
except AgentError:
raise
except Exception as e:
raise RuntimeError(f"Failed to generate kernel for {op_name}: {str(e)}")
raise AgentError(f"Agent error: Failed to generate kernel for {op_name}: {str(e)}")

def _extract_code_from_response(self, response: str) -> str:
if "```python" not in response:
raise ValueError(
"No Python code block found in LLM response. Response should contain ```python...``` block."
)

raise AgentError("Agent error: No Python code block found in LLM response.")
start = response.find("```python") + len("```python")
end = response.find("```", start)

if end == -1:
raise ValueError("Unclosed Python code block in LLM response.")

raise AgentError("Agent error: Unclosed Python code block in LLM response.")
return response[start:end].strip()


@@ -180,17 +185,24 @@ def call_llm(self, prompt: str) -> str:
else None
)

response = requests.post(
self.server_url,
json=request_data,
headers={"Content-Type": "application/json"},
timeout=120.0,
proxies=proxies,
)

if response.status_code != 200:
raise RuntimeError(f"Server returned status {response.status_code}: {response.text}")

response_data = response.json()
content = response_data.get("output", "")
return content
try:
response = requests.post(
self.server_url,
json=request_data,
headers={"Content-Type": "application/json"},
timeout=120.0,
proxies=proxies,
)
if response.status_code != 200:
raise AgentError(
f"Agent error: Server returned status {response.status_code}: {response.text}"
)
response_data = response.json()
content = response_data.get("output", "")
if not content or "rate limit" in content.lower():
raise AgentError("Agent error: Empty response or rate limit encountered.")
return content
except requests.exceptions.RequestException as e:
raise AgentError(f"Agent error: Failed to communicate with LLM relay server: {str(e)}")
except AgentError:
raise
except Exception as e:
raise AgentError(f"Agent error: Unexpected error in LLM relay call: {e}") from e