-
Notifications
You must be signed in to change notification settings - Fork 387
[fix] batch size in openai compatible endpoint #835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c5f7d1d
1f380d5
0e58637
5ee69f2
db8e198
f796133
3afee96
86490ff
1186ce5
2b9177c
1d90001
ab9fd5d
b7a6d3c
0bb1001
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,28 +1,19 @@ | ||||||||||||||||||||||
import base64 | ||||||||||||||||||||||
import json | ||||||||||||||||||||||
import os | ||||||||||||||||||||||
import time | ||||||||||||||||||||||
from io import BytesIO | ||||||||||||||||||||||
from typing import List, Tuple, Union | ||||||||||||||||||||||
from concurrent.futures import ThreadPoolExecutor, as_completed | ||||||||||||||||||||||
from typing import List | ||||||||||||||||||||||
|
||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||
import requests as url_requests | ||||||||||||||||||||||
from accelerate import Accelerator, DistributedType | ||||||||||||||||||||||
from tqdm import tqdm | ||||||||||||||||||||||
|
||||||||||||||||||||||
from lmms_eval.api.instance import Instance | ||||||||||||||||||||||
from lmms_eval.api.model import lmms | ||||||||||||||||||||||
from lmms_eval.api.registry import register_model | ||||||||||||||||||||||
|
||||||||||||||||||||||
try: | ||||||||||||||||||||||
from decord import VideoReader, cpu | ||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||
pass | ||||||||||||||||||||||
|
||||||||||||||||||||||
from dotenv import find_dotenv, load_dotenv | ||||||||||||||||||||||
from dotenv import load_dotenv | ||||||||||||||||||||||
from loguru import logger as eval_logger | ||||||||||||||||||||||
from openai import AzureOpenAI, OpenAI | ||||||||||||||||||||||
from PIL import Image | ||||||||||||||||||||||
|
||||||||||||||||||||||
from lmms_eval.models.model_utils.gen_metrics import log_metrics | ||||||||||||||||||||||
from lmms_eval.models.simple.openai_compatible import ( | ||||||||||||||||||||||
|
@@ -39,89 +30,117 @@ class OpenAICompatible(OpenAICompatibleSimple): | |||||||||||||||||||||
|
||||||||||||||||||||||
def generate_until(self, requests) -> List[str]: | ||||||||||||||||||||||
res = [] | ||||||||||||||||||||||
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") | ||||||||||||||||||||||
|
||||||||||||||||||||||
batch_size = getattr(self, "batch_size_per_gpu", 1) | ||||||||||||||||||||||
batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)] | ||||||||||||||||||||||
pbar = tqdm(total=len(batched_requests), disable=(self.rank != 0), desc="Model Responding") | ||||||||||||||||||||||
|
||||||||||||||||||||||
e2e_latency = 0 | ||||||||||||||||||||||
total_tokens = 0 | ||||||||||||||||||||||
for ctx, doc_to_messages, gen_kwargs, doc_id, task, split in [reg.args for reg in requests]: | ||||||||||||||||||||||
if self.continual_mode is True and self.cache_mode == "resume": | ||||||||||||||||||||||
doc_uuid = f"{task}___{split}___{doc_id}" | ||||||||||||||||||||||
if doc_uuid in self.response_cache: | ||||||||||||||||||||||
response_text = self.response_cache[doc_uuid] | ||||||||||||||||||||||
if response_text: | ||||||||||||||||||||||
res.append(response_text) | ||||||||||||||||||||||
pbar.update(1) | ||||||||||||||||||||||
continue | ||||||||||||||||||||||
|
||||||||||||||||||||||
chat_messages = doc_to_messages(self.task_dict[task][split][doc_id]) | ||||||||||||||||||||||
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages}) | ||||||||||||||||||||||
|
||||||||||||||||||||||
payload = {"messages": chat_messages.to_openai_messages()} | ||||||||||||||||||||||
payload["model"] = self.model_version | ||||||||||||||||||||||
|
||||||||||||||||||||||
if "max_new_tokens" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["max_new_tokens"] = 1024 | ||||||||||||||||||||||
if gen_kwargs["max_new_tokens"] > 4096: | ||||||||||||||||||||||
gen_kwargs["max_new_tokens"] = 4096 | ||||||||||||||||||||||
if "temperature" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["temperature"] = 0 | ||||||||||||||||||||||
if "top_p" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["top_p"] = None | ||||||||||||||||||||||
if "num_beams" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["num_beams"] = 1 | ||||||||||||||||||||||
|
||||||||||||||||||||||
# payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"] | ||||||||||||||||||||||
payload["max_tokens"] = gen_kwargs["max_new_tokens"] | ||||||||||||||||||||||
payload["temperature"] = gen_kwargs["temperature"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
if "o1" in self.model_version or "o3" in self.model_version or "o4" in self.model_version: | ||||||||||||||||||||||
# del payload["max_output_tokens"] | ||||||||||||||||||||||
del payload["temperature"] | ||||||||||||||||||||||
payload.pop("max_tokens") | ||||||||||||||||||||||
payload["reasoning_effort"] = "medium" | ||||||||||||||||||||||
payload["response_format"] = {"type": "text"} | ||||||||||||||||||||||
payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
for attempt in range(self.max_retries): | ||||||||||||||||||||||
try: | ||||||||||||||||||||||
start_time = time.time() | ||||||||||||||||||||||
response = self.client.chat.completions.create(**payload) | ||||||||||||||||||||||
end_time = time.time() | ||||||||||||||||||||||
|
||||||||||||||||||||||
response_text = response.choices[0].message.content | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Calculate timing metrics | ||||||||||||||||||||||
e2e_latency += end_time - start_time | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Get token counts from response if available | ||||||||||||||||||||||
if hasattr(response, "usage"): | ||||||||||||||||||||||
total_tokens += response.usage.completion_tokens | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
# Approximate token count if not provided | ||||||||||||||||||||||
total_tokens += len(response_text.split()) | ||||||||||||||||||||||
|
||||||||||||||||||||||
break # If successful, break out of the loop | ||||||||||||||||||||||
|
||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||
error_msg = str(e) | ||||||||||||||||||||||
eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}") | ||||||||||||||||||||||
|
||||||||||||||||||||||
# On last attempt, log error and set empty response | ||||||||||||||||||||||
if attempt == self.max_retries - 1: | ||||||||||||||||||||||
eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}") | ||||||||||||||||||||||
response_text = "" | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
time.sleep(self.timeout) | ||||||||||||||||||||||
|
||||||||||||||||||||||
res.append(response_text) | ||||||||||||||||||||||
pbar.update(1) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if self.continual_mode is True: # Cache the response | ||||||||||||||||||||||
for batch_requests in batched_requests: | ||||||||||||||||||||||
batch_payloads = [] | ||||||||||||||||||||||
batch_doc_uuids = [] | ||||||||||||||||||||||
batch_responses = [] | ||||||||||||||||||||||
|
||||||||||||||||||||||
for req in batch_requests: | ||||||||||||||||||||||
ctx, doc_to_messages, gen_kwargs, doc_id, task, split = req.args | ||||||||||||||||||||||
doc_uuid = f"{task}___{split}___{doc_id}" | ||||||||||||||||||||||
self.response_cache[doc_uuid] = response_text | ||||||||||||||||||||||
batch_doc_uuids.append(doc_uuid) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if self.continual_mode is True and self.cache_mode == "resume": | ||||||||||||||||||||||
if doc_uuid in self.response_cache: | ||||||||||||||||||||||
response_text = self.response_cache[doc_uuid] | ||||||||||||||||||||||
if response_text: | ||||||||||||||||||||||
batch_responses.append(response_text) | ||||||||||||||||||||||
continue | ||||||||||||||||||||||
|
||||||||||||||||||||||
chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id]) | ||||||||||||||||||||||
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw}) | ||||||||||||||||||||||
|
||||||||||||||||||||||
payload = {"messages": chat_messages.to_openai_messages()} | ||||||||||||||||||||||
payload["model"] = self.model_version | ||||||||||||||||||||||
|
||||||||||||||||||||||
if "max_new_tokens" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["max_new_tokens"] = 1024 | ||||||||||||||||||||||
if gen_kwargs["max_new_tokens"] > 4096: | ||||||||||||||||||||||
gen_kwargs["max_new_tokens"] = 4096 | ||||||||||||||||||||||
if "temperature" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["temperature"] = 0 | ||||||||||||||||||||||
if "top_p" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["top_p"] = None | ||||||||||||||||||||||
if "num_beams" not in gen_kwargs: | ||||||||||||||||||||||
gen_kwargs["num_beams"] = 1 | ||||||||||||||||||||||
|
||||||||||||||||||||||
payload["max_tokens"] = gen_kwargs["max_new_tokens"] | ||||||||||||||||||||||
payload["temperature"] = gen_kwargs["temperature"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
if "o1" in self.model_version or "o3" in self.model_version or "o4" in self.model_version: | ||||||||||||||||||||||
del payload["temperature"] | ||||||||||||||||||||||
payload.pop("max_tokens") | ||||||||||||||||||||||
payload["reasoning_effort"] = "medium" | ||||||||||||||||||||||
payload["response_format"] = {"type": "text"} | ||||||||||||||||||||||
payload["max_completion_tokens"] = gen_kwargs["max_new_tokens"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
batch_payloads.append(payload) | ||||||||||||||||||||||
batch_responses.append(None) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def process_single_request(payload, i): | ||||||||||||||||||||||
if batch_responses[i] is not None: | ||||||||||||||||||||||
return batch_responses[i], i, 0, 0 | ||||||||||||||||||||||
|
||||||||||||||||||||||
for attempt in range(self.max_retries): | ||||||||||||||||||||||
try: | ||||||||||||||||||||||
start_time = time.time() | ||||||||||||||||||||||
response = self.client.chat.completions.create(**payload) | ||||||||||||||||||||||
end_time = time.time() | ||||||||||||||||||||||
|
||||||||||||||||||||||
response_text = response.choices[0].message.content | ||||||||||||||||||||||
latency = end_time - start_time | ||||||||||||||||||||||
Comment on lines
+95
to
+99
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add per-call timeout to OpenAI request Prevent indefinite hangs by passing - response = self.client.chat.completions.create(**payload)
+ response = self.client.chat.completions.create(**payload, timeout=self.timeout) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||
|
||||||||||||||||||||||
tokens = 0 | ||||||||||||||||||||||
if hasattr(response, "usage"): | ||||||||||||||||||||||
tokens = response.usage.completion_tokens | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
tokens = len(response_text.split()) | ||||||||||||||||||||||
|
||||||||||||||||||||||
return response_text, i, latency, tokens | ||||||||||||||||||||||
|
||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||
error_msg = str(e) | ||||||||||||||||||||||
eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}") | ||||||||||||||||||||||
|
||||||||||||||||||||||
if attempt == self.max_retries - 1: | ||||||||||||||||||||||
eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}") | ||||||||||||||||||||||
return "", i, 0, 0 | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
time.sleep(self.timeout) | ||||||||||||||||||||||
|
||||||||||||||||||||||
return "", i, 0, 0 | ||||||||||||||||||||||
|
||||||||||||||||||||||
tasks_to_run = [(payload, i) for i, payload in enumerate(batch_payloads) if batch_responses[i] is None] | ||||||||||||||||||||||
|
||||||||||||||||||||||
if tasks_to_run: | ||||||||||||||||||||||
max_workers = min(len(tasks_to_run), 32) | ||||||||||||||||||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor: | ||||||||||||||||||||||
future_to_index = {executor.submit(process_single_request, payload, i): i for payload, i in tasks_to_run} | ||||||||||||||||||||||
|
||||||||||||||||||||||
for future in as_completed(future_to_index): | ||||||||||||||||||||||
response_text, i, latency, tokens = future.result() | ||||||||||||||||||||||
batch_responses[i] = response_text | ||||||||||||||||||||||
e2e_latency += latency | ||||||||||||||||||||||
total_tokens += tokens | ||||||||||||||||||||||
|
||||||||||||||||||||||
if self.continual_mode is True: | ||||||||||||||||||||||
for doc_uuid, response_text in zip(batch_doc_uuids, batch_responses): | ||||||||||||||||||||||
if response_text is not None: | ||||||||||||||||||||||
self.response_cache[doc_uuid] = response_text | ||||||||||||||||||||||
with open(self.response_persistent_file, "w") as f: | ||||||||||||||||||||||
json.dump(self.response_cache, f) | ||||||||||||||||||||||
|
||||||||||||||||||||||
res.extend([r for r in batch_responses if r is not None]) | ||||||||||||||||||||||
pbar.update(1) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Calculate average speed | ||||||||||||||||||||||
avg_speed = total_tokens / e2e_latency if e2e_latency > 0 else 0 | ||||||||||||||||||||||
# Log metrics | ||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indexing mismatch corrupts response alignment within batch
Same misalignment as the simple variant: cached responses are appended while payloads skip them, then indices from
enumerate(batch_payloads)
are used to updatebatch_responses
. This leads to wrong pairings.Also applies to: 121-133
🧰 Tools
🪛 Ruff (0.13.1)
47-47: Unpacked variable
ctx
is never usedPrefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents