1 change: 1 addition & 0 deletions configs/config.yaml
@@ -25,6 +25,7 @@ vllm:
max_retries: 3 # Number of retries for API calls
retry_delay: 1.0 # Initial delay between retries (seconds)
sleep_time: 0.1 # Small delay in seconds between batches to avoid rate limits
http_request_timeout: 180 # HTTP request timeout in seconds (3 minutes)

# API endpoint configuration
api-endpoint:
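The new key is read by the client below and, with this PR, can also be overridden per instance through the constructor. A minimal usage sketch (instantiating with defaults is an assumption; the override value is illustrative):

from synthetic_data_kit.models.llm_client import LLMClient

# Picks up vllm.http_request_timeout from the config (180 s here)
client = LLMClient()

# Explicit per-instance override for slow models (value is illustrative)
slow_client = LLMClient(http_request_timeout=600)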
178 changes: 125 additions & 53 deletions synthetic_data_kit/models/llm_client.py
@@ -11,6 +11,7 @@
import os
import logging
import asyncio
import aiohttp
from pathlib import Path

from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider
@@ -36,7 +37,8 @@ def __init__(self,
api_key: Optional[str] = None,
model_name: Optional[str] = None,
max_retries: Optional[int] = None,
retry_delay: Optional[float] = None):
retry_delay: Optional[float] = None,
http_request_timeout: Optional[int] = None):
"""Initialize an LLM client that supports multiple providers

Args:
@@ -92,6 +94,7 @@ def __init__(self,
self.max_retries = max_retries or vllm_config.get('max_retries')
self.retry_delay = retry_delay or vllm_config.get('retry_delay')
self.sleep_time = vllm_config.get('sleep_time',0.1)
self.http_request_timeout = http_request_timeout or vllm_config.get('http_request_timeout', 180)

# No client to initialize for vLLM as we use requests directly
# Verify server is running
@@ -304,7 +307,7 @@ def _vllm_chat_completion(self,
f"{self.api_base}/chat/completions",
headers={"Content-Type": "application/json"},
data=json.dumps(data),
timeout=180 # Increased timeout to 180 seconds
timeout=self.http_request_timeout # Configurable HTTP request timeout
)

if verbose:
@@ -500,12 +503,6 @@ def _openai_batch_completion(self,
if verbose:
logger.info(f"Processing batch {i//batch_size + 1}/{(len(message_batches) + batch_size - 1) // batch_size} with {len(batch_chunk)} requests")

# Import asyncio here to avoid issues if not available
try:
import asyncio
except ImportError:
raise ImportError("The 'asyncio' package is required for batch processing. Please ensure you're using Python 3.7+.")

# Define async batch processing function
async def process_batch():
tasks = []
@@ -534,13 +531,13 @@ async def process_batch():
return results

def _vllm_batch_completion(self,
message_batches: List[List[Dict[str, str]]],
temperature: float,
max_tokens: int,
top_p: float,
batch_size: int,
verbose: bool) -> List[str]:
"""Process multiple message sets in batches using vLLM's API"""
message_batches: List[List[Dict[str, str]]],
temperature: float,
max_tokens: int,
top_p: float,
batch_size: int,
verbose: bool) -> List[str]:
"""Process multiple message sets in true batches using vLLM's API with concurrent requests"""
results = []

# Process message batches in chunks to avoid overloading the server
@@ -549,49 +546,124 @@ def _vllm_batch_completion(self,
if verbose:
logger.info(f"Processing batch {i//batch_size + 1}/{(len(message_batches) + batch_size - 1) // batch_size} with {len(batch_chunk)} requests")

# Create batch request payload for VLLM
batch_requests = []
for messages in batch_chunk:
batch_requests.append({
"model": self.model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p
})

try:
# For now, we run these in parallel with multiple requests
batch_results = []
for request_data in batch_requests:
# Only print if verbose mode is enabled
if verbose:
logger.info(f"Sending batch request to vLLM model {self.model}...")

response = requests.post(
f"{self.api_base}/chat/completions",
headers={"Content-Type": "application/json"},
data=json.dumps(request_data),
timeout=180 # Increased timeout for batch processing
)

if verbose:
logger.info(f"Received response with status code: {response.status_code}")

response.raise_for_status()
content = response.json()["choices"][0]["message"]["content"]
batch_results.append(content)

results.extend(batch_results)

except (requests.exceptions.RequestException, KeyError, IndexError) as e:
raise Exception(f"Failed to process vLLM batch: {str(e)}")
# Run the async batch processing
batch_results = asyncio.run(self._process_vllm_batch_async(
batch_chunk, temperature, max_tokens, top_p, verbose, batch_size
))
results.extend(batch_results)

# Small delay between batches
# Small delay between batches to avoid rate limits
if i + batch_size < len(message_batches):
time.sleep(self.sleep_time)

return results
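
One caveat with this approach: asyncio.run() raises RuntimeError when called from a thread that already has a running event loop (for example, inside Jupyter). A possible guard, sketched here as an illustration and not part of this PR:

import asyncio
from concurrent.futures import ThreadPoolExecutor

def run_coro_blocking(coro):
    # Sketch only: fall back to a worker thread when an event loop
    # is already running in the current thread.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no running loop, asyncio.run is safe
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()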

async def _process_vllm_batch_async(self,
batch_chunk: List[List[Dict[str, str]]],
temperature: float,
max_tokens: int,
top_p: float,
verbose: bool,
batch_size: int) -> List[str]:
"""Process a batch of requests asynchronously using aiohttp"""

async def process_single_request(session: aiohttp.ClientSession,
messages: List[Dict[str, str]],
semaphore: asyncio.Semaphore,
                                         http_request_timeout: int) -> str:
"""Process a single request with retry logic"""
data = {
"model": self.model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p
}

async with semaphore: # Limit concurrent requests
for attempt in range(self.max_retries):
try:
if verbose and attempt == 0: # Only log on first attempt
logger.info(f"Sending async request to vLLM model {self.model}...")

async with session.post(
f"{self.api_base}/chat/completions",
headers={"Content-Type": "application/json"},
data=json.dumps(data),
                            timeout=aiohttp.ClientTimeout(total=http_request_timeout) # configurable per-request timeout
) as response:

if verbose and attempt == 0:
logger.info(f"Received response with status code: {response.status}")

response.raise_for_status()
response_json = await response.json()

try:
return response_json["choices"][0]["message"]["content"]
except (KeyError, IndexError) as e:
raise ValueError(f"Invalid response format: {e}")

except asyncio.TimeoutError:
error_msg = f"Request timeout on attempt {attempt + 1}/{self.max_retries}"
if verbose:
logger.warning(error_msg)
if attempt == self.max_retries - 1:
return f"ERROR: {error_msg}"

except aiohttp.ClientError as e:
error_msg = f"HTTP error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
if verbose:
logger.warning(error_msg)
if attempt == self.max_retries - 1:
return f"ERROR: {error_msg}"

except Exception as e:
error_msg = f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
if verbose:
logger.warning(error_msg)
if attempt == self.max_retries - 1:
return f"ERROR: {error_msg}"

                # Linearly increasing backoff between retries
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1))

# Create semaphore to limit concurrent connections (prevent overwhelming the server)
max_concurrent = min(batch_size, 1024) # Cap at 1024 concurrent requests
semaphore = asyncio.Semaphore(max_concurrent)

# Create aiohttp session with connection pooling
connector = aiohttp.TCPConnector(
limit=max_concurrent * 2, # Total connection pool size
limit_per_host=max_concurrent, # Connections per host
ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True,
)

timeout = aiohttp.ClientTimeout(total=300) # Session-level default (5 minutes); the per-request timeout above takes precedence

async with aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={"Content-Type": "application/json"}
) as session:
# Create tasks for all requests in the batch
tasks = []
for messages in batch_chunk:
task = process_single_request(session, messages, semaphore, self.http_request_timeout)
tasks.append(task)

if verbose:
logger.info(f"Starting {len(tasks)} concurrent requests...")

# Execute all requests concurrently
results = await asyncio.gather(*tasks, return_exceptions=False)

if verbose:
logger.info(f"Completed {len(results)} requests")

return results
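
The concurrency pattern above (a semaphore bounding in-flight requests, fanned out with asyncio.gather over a shared session) is worth seeing in isolation. A minimal standalone sketch; the URL and payloads are placeholders, not part of this PR:

import asyncio
import aiohttp

async def fan_out(payloads, url, max_concurrent=8):
    semaphore = asyncio.Semaphore(max_concurrent)

    async def post_one(session, payload):
        async with semaphore:  # at most max_concurrent requests in flight
            async with session.post(url, json=payload) as resp:
                resp.raise_for_status()
                return await resp.json()

    async with aiohttp.ClientSession() as session:
        # gather preserves input order in its results
        return await asyncio.gather(*(post_one(session, p) for p in payloads))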

@classmethod
def from_config(cls, config_path: Path) -> 'LLMClient':
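For reference, a hypothetical call to the from_config entry point shown above; the config path is a placeholder:

from pathlib import Path
from synthetic_data_kit.models.llm_client import LLMClient

client = LLMClient.from_config(Path("configs/config.yaml"))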