@@ -110,9 +110,7 @@ def generate_kernels(
110110 try :
111111 # Create problem description for the operator
112112 folder_name = op_name_to_folder_name (op_name )
113- problem_description = _create_problem_description_from_op (
114- op , op_name , folder_name
115- )
113+ problem_description = _create_problem_description_from_op (op , op_name )
116114
117115 # Create test code from BackendBench tests if provided
118116 test_code = None
@@ -133,6 +131,18 @@ def generate_kernels(
133131 if result ["success" ]:
134132 kernel_code = result ["kernel_code" ]
135133
134+ # Automatically fix function name to match BackendBench's expectations
135+ # Replace generic function names with the required name
136+ import re
137+
138+ expected_func_name = f"{ folder_name } _kernel_impl"
139+ kernel_code = re .sub (
140+ r"\bdef\s+(kernel_function)\s*\(" ,
141+ f"def { expected_func_name } (" ,
142+ kernel_code ,
143+ )
144+ logger .debug (f" Ensured function name is: { expected_func_name } " )
145+
136146 # Create operator directory (e.g., generated_kernels/abs__default/)
137147 folder_name = op_name_to_folder_name (op_name )
138148 op_dir = os .path .join (kernels_dir , folder_name )
@@ -217,27 +227,19 @@ def evaluate_kernels(
217227 return result .returncode
218228
219229
220- def _create_problem_description_from_op (op , op_name : str , folder_name : str ) -> str :
230+ def _create_problem_description_from_op (op , op_name : str ) -> str :
221231 """
222232 Create a problem description for KernelAgent based on the PyTorch operation.
223233
224234 Args:
225235 op: PyTorch operation
226236 op_name: Operation name extracted from op
227- folder_name: Folder name for the operator (e.g., abs__default)
228237
229238 Returns:
230239 Problem description string for KernelAgent
231240 """
232241 # Create a comprehensive problem description that KernelAgent can understand
233242 problem_description = f"""
234- CRITICAL REQUIREMENT - FUNCTION NAMING:
235- The main wrapper function MUST be named EXACTLY: { folder_name } _kernel_impl
236- This is MANDATORY. Do NOT use 'kernel_function' or any other name.
237- Example for this operator:
238- def { folder_name } _kernel_impl(*args, **kwargs):
239- # Your triton kernel implementation
240-
241243Task: Implement a high-performance Triton kernel for the PyTorch operation: { op_name }
242244
243245Requirements:
0 commit comments