diff --git a/BackendBench/backends/llm.py b/BackendBench/backends/llm.py
index aba20de..2d98af1 100644
--- a/BackendBench/backends/llm.py
+++ b/BackendBench/backends/llm.py
@@ -16,6 +16,7 @@
 import torch
 
+from BackendBench.errors import AgentError
 from BackendBench.eval import (
     CorrectnessTestResult,
     eval_performance,
@@ -306,6 +307,11 @@ def test_kernel_correctness(
         feedback_info.kernel_code = kernel_code
 
         try:
+            # Only raise AgentError if kernel_code is missing or malformed
+            if not kernel_code or not isinstance(kernel_code, str):
+                raise AgentError(
+                    "Kernel code is empty or not a string (agent failed to produce a kernel)."
+                )
             kernel_file = self._generate_kernel_file_path(op_name, attempt)
             if not os.path.exists(kernel_file):
                 save_kernel_to_file(kernel_code, kernel_file)
@@ -314,16 +320,12 @@ def test_kernel_correctness(
                 f"{op_name}_implementation_v{attempt}", kernel_file
             )
             module = importlib.util.module_from_spec(spec)
-
-            # Add to sys.modules so triton can find it
             sys.modules[f"{op_name}_implementation_v{attempt}"] = module
 
             try:
                 spec.loader.exec_module(module)
-
                 expected_name = f"{op_name}_kernel_impl"
                 if hasattr(module, expected_name):
-                    # check if the kernel compile / is loadable
                     _ = getattr(module, expected_name)
                 else:
                     available_functions = [
@@ -334,12 +336,9 @@ def test_kernel_correctness(
                     raise ValueError(
                         f"Expected function '{expected_name}' not found. Available: {available_functions}"
                     )
-
             finally:
                 if f"test_kernel_{op_name}_{attempt}" in sys.modules:
                     del sys.modules[f"test_kernel_{op_name}_{attempt}"]
-
-            # Clear CUDA cache and synchronize to prevent memory buildup
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
                 torch.cuda.synchronize()
@@ -353,10 +352,7 @@ def test_kernel_correctness(
                 test_cases,
                 [],
             )
-
-            # Start evaluation
             evaluator.start_evaluation()
-            # Get results
             results = evaluator.get_results()
 
             for result in results:
@@ -370,6 +366,10 @@ def test_kernel_correctness(
             return is_correct, feedback_info
 
+        except AgentError as e:
+            feedback_info.agent_error = str(e)
+            feedback_info.summary = f"Agent error: {str(e)}"
+            return False, feedback_info
         except Exception as e:
             logger.error(" ✗ Compilation failed:")
             logger.error(f" Error: {str(e)}")
diff --git a/BackendBench/errors.py b/BackendBench/errors.py
new file mode 100644
index 0000000..b217566
--- /dev/null
+++ b/BackendBench/errors.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+class AgentError(Exception):
+    """
+    Exception raised for errors related to LLM/agent failures,
+    such as rate limits, empty code, bad formatting, or API issues.
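+
+    Example (illustrative):
+        raise AgentError("No Python code block found in LLM response.")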
+ """ + + def __init__(self, message: str): + super().__init__(message) + self.message = message diff --git a/BackendBench/llm_client.py b/BackendBench/llm_client.py index 35e575c..c1cbca3 100644 --- a/BackendBench/llm_client.py +++ b/BackendBench/llm_client.py @@ -9,9 +9,12 @@ import anthropic import requests +from requests.exceptions import ConnectionError from tenacity import retry from tenacity.wait import wait_random_exponential +from BackendBench.errors import AgentError + from .kernel_templates import KernelTemplateManager @@ -60,15 +63,26 @@ def readme_setup_section(self) -> str: @retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2)) def call_llm(self, prompt: str) -> str: - response = self.client.messages.create( - model=self.model, - max_tokens=8000, - temperature=0.2, - timeout=120.0, - messages=[{"role": "user", "content": prompt}], - ) - content = response.content[0].text - return content + try: + response = self.client.messages.create( + model=self.model, + max_tokens=8000, + temperature=0.2, + timeout=120.0, + messages=[{"role": "user", "content": prompt}], + ) + content = response.content[0].text + if not content: + raise ConnectionError( + "API error: Empty response from LLM API (possible rate limit or outage)." + ) + if "rate limit" in content.lower(): + raise ConnectionError("API error: Rate limit encountered from LLM API.") + return content + except anthropic.AnthropicError as e: + raise ConnectionError(f"API error: Anthropic API error: {e}") + except Exception as e: + raise ConnectionError(f"API error: Unexpected error: {e}") def generate_kernel( self, @@ -93,9 +107,7 @@ def generate_kernel( try: content = self.call_llm(prompt) - if not content: - raise RuntimeError("Empty response from LLM relay server") - + # Only raise AgentError if kernel extraction fails extracted_code = self._extract_code_from_response(content) print("\n=== DEBUG: RAW LLM RELAY RESPONSE ===") @@ -106,25 +118,18 @@ def generate_kernel( return extracted_code - except requests.exceptions.RequestException as e: - raise RuntimeError( - f"Failed to communicate with LLM relay server for {op_name}: {str(e)}" - ) + except AgentError: + raise except Exception as e: - raise RuntimeError(f"Failed to generate kernel for {op_name}: {str(e)}") + raise AgentError(f"Agent error: Failed to generate kernel for {op_name}: {str(e)}") def _extract_code_from_response(self, response: str) -> str: if "```python" not in response: - raise ValueError( - "No Python code block found in LLM response. Response should contain ```python...``` block." 
-            )
-
+            raise AgentError("Agent error: No Python code block found in LLM response.")
         start = response.find("```python") + len("```python")
         end = response.find("```", start)
-
         if end == -1:
-            raise ValueError("Unclosed Python code block in LLM response.")
-
+            raise AgentError("Agent error: Unclosed Python code block in LLM response.")
         return response[start:end].strip()
@@ -180,17 +185,27 @@ def call_llm(self, prompt: str) -> str:
             else None
         )
 
-        response = requests.post(
-            self.server_url,
-            json=request_data,
-            headers={"Content-Type": "application/json"},
-            timeout=120.0,
-            proxies=proxies,
-        )
-
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status {response.status_code}: {response.text}")
-
-        response_data = response.json()
-        content = response_data.get("output", "")
-        return content
+        try:
+            response = requests.post(
+                self.server_url,
+                json=request_data,
+                headers={"Content-Type": "application/json"},
+                timeout=120.0,
+                proxies=proxies,
+            )
+            if response.status_code != 200:
+                raise AgentError(
+                    f"Agent error: Server returned status {response.status_code}: {response.text}"
+                )
+            response_data = response.json()
+            content = response_data.get("output", "")
+            if not content or "rate limit" in content.lower():
+                raise AgentError("Agent error: Empty response or rate limit encountered.")
+            return content
+        except AgentError:
+            # Let our own errors propagate untouched instead of re-wrapping them below.
+            raise
+        except requests.exceptions.RequestException as e:
+            raise AgentError(f"Agent error: Failed to communicate with LLM relay server: {str(e)}")
+        except Exception as e:
+            raise AgentError(f"Agent error: Unexpected error in LLM relay call: {e}")