Skip to content

Commit ac7cd0f

Browse files
author
Kaiming Cheng
committed
fix function name
1 parent 53eed66 commit ac7cd0f

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

benchmark/BackendBench/eval.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ def generate_kernels(
110110
try:
111111
# Create problem description for the operator
112112
folder_name = op_name_to_folder_name(op_name)
113-
problem_description = _create_problem_description_from_op(
114-
op, op_name, folder_name
115-
)
113+
problem_description = _create_problem_description_from_op(op, op_name)
116114

117115
# Create test code from BackendBench tests if provided
118116
test_code = None
@@ -133,6 +131,18 @@ def generate_kernels(
133131
if result["success"]:
134132
kernel_code = result["kernel_code"]
135133

134+
# Automatically fix function name to match BackendBench's expectations
135+
# Replace generic function names with the required name
136+
import re
137+
138+
expected_func_name = f"{folder_name}_kernel_impl"
139+
kernel_code = re.sub(
140+
r"\bdef\s+(kernel_function)\s*\(",
141+
f"def {expected_func_name}(",
142+
kernel_code,
143+
)
144+
logger.debug(f" Ensured function name is: {expected_func_name}")
145+
136146
# Create operator directory (e.g., generated_kernels/abs__default/)
137147
folder_name = op_name_to_folder_name(op_name)
138148
op_dir = os.path.join(kernels_dir, folder_name)
@@ -217,27 +227,19 @@ def evaluate_kernels(
217227
return result.returncode
218228

219229

220-
def _create_problem_description_from_op(op, op_name: str, folder_name: str) -> str:
230+
def _create_problem_description_from_op(op, op_name: str) -> str:
221231
"""
222232
Create a problem description for KernelAgent based on the PyTorch operation.
223233
224234
Args:
225235
op: PyTorch operation
226236
op_name: Operation name extracted from op
227-
folder_name: Folder name for the operator (e.g., abs__default)
228237
229238
Returns:
230239
Problem description string for KernelAgent
231240
"""
232241
# Create a comprehensive problem description that KernelAgent can understand
233242
problem_description = f"""
234-
CRITICAL REQUIREMENT - FUNCTION NAMING:
235-
The main wrapper function MUST be named EXACTLY: {folder_name}_kernel_impl
236-
This is MANDATORY. Do NOT use 'kernel_function' or any other name.
237-
Example for this operator:
238-
def {folder_name}_kernel_impl(*args, **kwargs):
239-
# Your triton kernel implementation
240-
241243
Task: Implement a high-performance Triton kernel for the PyTorch operation: {op_name}
242244
243245
Requirements:

0 commit comments

Comments
 (0)