diff --git a/BackendBench/backends/llm.py b/BackendBench/backends/llm.py
index 6e60eb4..f9ee6db 100644
--- a/BackendBench/backends/llm.py
+++ b/BackendBench/backends/llm.py
@@ -14,6 +14,7 @@
 import torch
 
+from BackendBench.eval import eval_performance
 from BackendBench.llm_client import LLMKernelGenerator
 from BackendBench.multiprocessing_eval import MultiprocessingEvaluator
 from BackendBench.utils import (
@@ -257,6 +258,34 @@ def test_kernel_correctness(
             feedback_info["summary"] = "Compilation failed"
             return False, feedback_info
 
+    def test_kernel_performance(
+        self, op, kernel_code: str, performance_tests: List, attempt: int = 1
+    ) -> tuple[float, List]:
+        """Test kernel performance and return the performance score with results."""
+
+        op_str = str(op)
+        op_name = extract_operator_name(op_str)
+        kernel_file = self._generate_kernel_file_path(op_name, attempt)
+
+        # Use compile_kernel_from_string for consistent loading
+        module_name = f"{op_name}_implementation_v{attempt}"
+        try:
+            kernel_impl = compile_kernel_from_string(
+                kernel_code=kernel_code,
+                op_name=op_name,
+                kernel_file_path=kernel_file,
+                expected_fn_name=op_name,
+                module_name=module_name,
+            )
+            performance_score, performance_results = eval_performance(
+                op, kernel_impl, performance_tests
+            )
+        except Exception as e:
+            logger.error(f"Performance evaluation failed: {str(e)}")
+            performance_score, performance_results = 0.0, []
+
+        return performance_score, performance_results
+
     def __getitem__(self, key):
         if key in self.compiled_kernels:
             return self.compiled_kernels[key]
@@ -303,11 +332,21 @@ def generate_kernels(self, suite, max_attempts=5):
 
             # Create feedback callback
             def feedback_callback(kernel_code: str, attempt: int) -> tuple[bool, Dict]:
-                # TODO: Add performance testing in addition to correctness testing
-                return self.test_kernel_correctness(
+                is_correct, feedback_info = self.test_kernel_correctness(
                     op, kernel_code, op_test.correctness_tests, attempt
                 )
+                if is_correct:
+                    perf_score, perf_results = self.test_kernel_performance(
+                        op, kernel_code, op_test.performance_tests, attempt
+                    )
+                    feedback_info["performance_score"] = perf_score
+                    feedback_info["performance_results"] = perf_results
+                else:
+                    feedback_info["performance_score"] = "N/A"
+                    feedback_info["performance_results"] = []
+                return is_correct, feedback_info
+
 
             # Generate kernel with iterative refinement
             kernel_code, attempts_used, success = self.llm_client.generate_kernel_with_retry(
                 op_name,
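
Not part of the diff: a minimal sketch of how a caller might consume the feedback dict that the updated feedback_callback returns. The field names mirror the diff above; the variable candidate_kernel_code is hypothetical, and the shape of the "performance_results" entries is an assumption, since eval_performance's return format is not shown here.

# Illustrative sketch only; feedback_callback is a closure inside
# generate_kernels, so this mirrors how that code handles the result.
is_correct, feedback_info = feedback_callback(candidate_kernel_code, attempt=1)

if is_correct:
    # Set by the diff when correctness passes: a float score from
    # eval_performance plus its per-test results (shape assumed).
    print(f"performance score: {feedback_info['performance_score']}")
    for result in feedback_info["performance_results"]:
        print(result)
else:
    # Performance fields are placeholders when correctness fails.
    assert feedback_info["performance_score"] == "N/A"
    assert feedback_info["performance_results"] == []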