BackendBench/backends/llm.py (41 additions, 2 deletions)
@@ -14,6 +14,7 @@

import torch

from BackendBench.eval import eval_performance
from BackendBench.llm_client import LLMKernelGenerator
from BackendBench.multiprocessing_eval import MultiprocessingEvaluator
from BackendBench.utils import (
@@ -257,6 +258,34 @@ def test_kernel_correctness(
feedback_info["summary"] = "Compilation failed"
return False, feedback_info

def test_kernel_performance(
self, op, kernel_code: str, performance_tests: List, attempt: int = 1
) -> tuple[float, List]:
"""Test kernel performance return performance score with results."""

op_str = str(op)
op_name = extract_operator_name(op_str)
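# Build the per-attempt file path for this operator's generated kernel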
kernel_file = self._generate_kernel_file_path(op_name, attempt)

# Use compile_kernel_from_string for consistent loading
module_name = f"{op_name}_implementation_v{attempt}"
try:
kernel_impl = compile_kernel_from_string(
kernel_code=kernel_code,
op_name=op_name,
kernel_file_path=kernel_file,
expected_fn_name=op_name,
module_name=module_name,
)
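# Benchmark the compiled kernel against the suite's performance tests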
performance_score, performance_results = eval_performance(
op, kernel_impl, performance_tests
)
except Exception as e:
logger.error(f"Performance evaluation failed: {str(e)}")
performance_score, performance_results = 0.0, []

return performance_score, performance_results

def __getitem__(self, key):
if key in self.compiled_kernels:
return self.compiled_kernels[key]
@@ -303,11 +332,21 @@ def generate_kernels(self, suite, max_attempts=5):

# Create feedback callback
def feedback_callback(kernel_code: str, attempt: int) -> tuple[bool, Dict]:
is_correct, feedback_info = self.test_kernel_correctness(
op, kernel_code, op_test.correctness_tests, attempt
)

if is_correct:
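# Kernel passed correctness, so benchmark it against the performance tests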
perf_score, perf_results = self.test_kernel_performance(
op, kernel_code, op_test.performance_tests, attempt
)
feedback_info["performance_score"] = perf_score
feedback_info["performance_results"] = perf_results
else:
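# No meaningful performance score for an incorrect kernel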
feedback_info["performance_score"] = "N/A"
feedback_info["performance_results"] = []
return is_correct, feedback_info

# Generate kernel with iterative refinement
kernel_code, attempts_used, success = self.llm_client.generate_kernel_with_retry(
op_name,