Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions aidial_rag/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
RequestType,
get_configuration,
)
from aidial_rag.dial_api_client import create_dial_api_client
from aidial_rag.document_record import Chunk, DocumentRecord
from aidial_rag.documents import load_documents
from aidial_rag.index_record import ChunkMetadata, RetrievalType
Expand Down Expand Up @@ -248,7 +247,7 @@ async def chat_completion(
self, request: Request, response: Response
) -> None:
loop = asyncio.get_running_loop()
with create_request_context(
async with create_request_context(
self.app_config.dial_url, request, response
) as request_context:
choice = request_context.choice
Expand All @@ -274,21 +273,19 @@ async def chat_completion(
get_attachment_links(request_context, messages)
)

dial_api_client = await create_dial_api_client(request_context)
index_storage = self.index_storage_holder.get_storage(
dial_api_client
request_context.dial_api_client
)

indexing_tasks = create_indexing_tasks(
attachment_links,
dial_api_client,
request_context.dial_api_client,
)

indexing_results = await load_documents(
request_context,
indexing_tasks,
index_storage,
dial_api_client,
config=request_config,
)

Expand Down
3 changes: 2 additions & 1 deletion aidial_rag/attachment_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ def to_dial_metadata_url(
if not request_context.is_dial_url(absolute_url):
return None

return urljoin(
absolute_metadata_url = urljoin(
request_context.dial_metadata_base_url, link, allow_fragments=True
)
return to_dial_relative_url(request_context, absolute_metadata_url)


class AttachmentLink(BaseModel):
Expand Down
61 changes: 29 additions & 32 deletions aidial_rag/dial_api_client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import io
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import aiohttp

from aidial_rag.request_context import RequestContext
from aidial_rag.dial_config import DialConfig


async def _get_bucket_id(dial_base_url, headers: dict) -> str:
async def _get_bucket_id(session: aiohttp.ClientSession, headers: dict) -> str:
relative_url = (
"bucket" # /v1/ is already included in the base url for the Dial API
)
async with aiohttp.ClientSession(base_url=dial_base_url) as session:
async with session.get(relative_url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
return data["bucket"]
async with session.get(relative_url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
return data["bucket"]


def _to_form_data(key: str, data: bytes, content_type: str) -> aiohttp.FormData:
Expand All @@ -25,39 +26,35 @@ def _to_form_data(key: str, data: bytes, content_type: str) -> aiohttp.FormData:


class DialApiClient:
def __init__(self, dial_api_base_url: str, headers: dict, bucket_id: str):
def __init__(self, client_session: aiohttp.ClientSession, bucket_id: str):
self._client_session = client_session
self.bucket_id = bucket_id

self._dial_api_base_url = dial_api_base_url
self._headers = headers
@property
def session(self) -> aiohttp.ClientSession:
return self._client_session

async def get_file(self, relative_url: str) -> bytes | None:
async with aiohttp.ClientSession(
base_url=self._dial_api_base_url
) as session:
async with session.get(
relative_url, headers=self._headers
) as response:
response.raise_for_status()
return await response.read()
async with self.session.get(relative_url) as response:
response.raise_for_status()
return await response.read()

async def put_file(
self, relative_url: str, data: bytes, content_type: str
) -> dict:
async with aiohttp.ClientSession(
base_url=self._dial_api_base_url
) as session:
form_data = _to_form_data(relative_url, data, content_type)
async with session.put(
relative_url, data=form_data, headers=self._headers
) as response:
response.raise_for_status()
return await response.json()
form_data = _to_form_data(relative_url, data, content_type)
async with self.session.put(relative_url, data=form_data) as response:
response.raise_for_status()
return await response.json()


@asynccontextmanager
async def create_dial_api_client(
request_context: RequestContext,
) -> DialApiClient:
headers = request_context.get_api_key_headers()
bucket_id = await _get_bucket_id(request_context.dial_base_url, headers)
return DialApiClient(request_context.dial_base_url, headers, bucket_id)
config: DialConfig,
) -> AsyncGenerator[DialApiClient, None]:
headers = {"api-key": config.api_key.get_secret_value()}
async with aiohttp.ClientSession(
base_url=config.dial_base_url, headers=headers
) as session:
bucket_id = await _get_bucket_id(session, headers)
yield DialApiClient(session, bucket_id)
4 changes: 4 additions & 0 deletions aidial_rag/dial_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
class DialConfig(BaseConfig):
dial_url: str
api_key: SecretStr

@property
def dial_base_url(self) -> str:
return f"{self.dial_url}/v1/"
20 changes: 8 additions & 12 deletions aidial_rag/dial_user_limits.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import aiohttp
from pydantic import BaseModel, Field

from aidial_rag.dial_config import DialConfig
from aidial_rag.dial_api_client import DialApiClient


class TokenStats(BaseModel):
Expand All @@ -20,18 +19,15 @@ class UserLimitsForModel(BaseModel):


async def get_user_limits_for_model(
dial_config: DialConfig, deployment_name: str
dial_api_client: DialApiClient, deployment_name: str
) -> UserLimitsForModel:
"""Returns the user limits for the specified model deployment.

See https://epam-rail.com/dial_api#tag/Limits for the API documentation.
"""
headers = {"Api-Key": dial_config.api_key.get_secret_value()}
limits_url = (
f"{dial_config.dial_url}/v1/deployments/{deployment_name}/limits"
)
async with aiohttp.ClientSession() as session:
async with session.get(limits_url, headers=headers) as response:
response.raise_for_status()
limits_json = await response.json()
return UserLimitsForModel.model_validate(limits_json)

limits_relative_url = f"deployments/{deployment_name}/limits"
async with dial_api_client.session.get(limits_relative_url) as response:
response.raise_for_status()
limits_json = await response.json()
return UserLimitsForModel.model_validate(limits_json)
58 changes: 32 additions & 26 deletions aidial_rag/document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from aidial_rag.attachment_link import AttachmentLink
from aidial_rag.base_config import BaseConfig, IndexRebuildTrigger
from aidial_rag.content_stream import SupportsWriteStr
from aidial_rag.dial_api_client import DialApiClient
from aidial_rag.errors import InvalidDocumentError
from aidial_rag.image_processor.extract_pages import (
are_image_pages_supported,
extract_number_of_pages,
)
from aidial_rag.print_stats import print_documents_stats
from aidial_rag.request_context import RequestContext
from aidial_rag.resources.cpu_pools import run_in_indexing_cpu_pool
from aidial_rag.utils import format_size, get_bytes_length, timed_block

Expand Down Expand Up @@ -85,18 +85,17 @@ class ParserConfig(BaseConfig):


async def download_attachment(
url, headers, download_config: HttpClientConfig
url: str, session: aiohttp.ClientSession, download_config: HttpClientConfig
) -> tuple[str, bytes]:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers=headers, timeout=download_config.get_client_timeout()
) as response:
response.raise_for_status()
content_type = response.headers.get("Content-Type", "")
async with session.get(
url, timeout=download_config.get_client_timeout()
) as response:
response.raise_for_status()
content_type = response.headers.get("Content-Type", "")

content = await response.read() # Await the coroutine
logging.debug(f"Downloaded {url}: {len(content)} bytes")
return content_type, content
content = await response.read()
logging.debug(f"Downloaded {url}: {len(content)} bytes")
return content_type, content


def add_source_metadata(
Expand All @@ -121,7 +120,7 @@ def add_pdf_source_metadata(


async def load_dial_document_metadata(
request_context: RequestContext,
dial_api_client: DialApiClient,
attachment_link: AttachmentLink,
config: HttpClientConfig,
) -> dict:
Expand All @@ -131,29 +130,36 @@ async def load_dial_document_metadata(
metadata_url = attachment_link.dial_metadata_url
assert metadata_url is not None

headers = request_context.get_file_access_headers(metadata_url)
async with aiohttp.ClientSession(
timeout=config.get_client_timeout()
) as session:
async with session.get(metadata_url, headers=headers) as response:
if not response.ok:
error_message = f"{response.status} {response.reason}"
raise InvalidDocumentError(error_message)
return await response.json()
async with dial_api_client.session.get(
metadata_url, timeout=config.get_client_timeout()
) as response:
if not response.ok:
error_message = f"{response.status} {response.reason}"
raise InvalidDocumentError(error_message)
return await response.json()


async def load_attachment(
dial_api_client: DialApiClient,
attachment_link: AttachmentLink,
headers: dict,
download_config: HttpClientConfig | None = None,
) -> tuple[str, str, bytes]:
if download_config is None:
download_config = HttpClientConfig()
absolute_url = attachment_link.absolute_url
file_name = attachment_link.display_name
content_type, attachment_bytes = await download_attachment(
absolute_url, headers, download_config
)

if attachment_link.is_dial_document:
content_type, attachment_bytes = await download_attachment(
attachment_link.dial_link, dial_api_client.session, download_config
)
else:
# Use separate session for non-Dial documents
# to avoid passing Dial headers to non-Dial servers
async with aiohttp.ClientSession() as session:
content_type, attachment_bytes = await download_attachment(
attachment_link.absolute_url, session, download_config
)

if attachment_bytes:
return file_name, content_type, attachment_bytes
raise InvalidDocumentError(
Expand Down
26 changes: 9 additions & 17 deletions aidial_rag/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ async def check_document_access(
) as access_stage:
try:
await load_dial_document_metadata(
request_context, attachment_link, config.check_access
request_context.dial_api_client,
attachment_link,
config.check_access,
)
except InvalidDocumentError as e:
access_stage.append_content(e.message)
Expand All @@ -102,6 +104,7 @@ def get_default_image_chunk(attachment_link: AttachmentLink):


async def load_document_impl(
dial_api_client: DialApiClient,
dial_config: DialConfig,
dial_limited_resources: DialLimitedResources,
attachment_link: AttachmentLink,
Expand All @@ -116,16 +119,9 @@ async def load_document_impl(
)
io_stream = MultiStream(MarkdownStream(stage_stream), logger_stream)

absolute_url = attachment_link.absolute_url
headers = (
{"api-key": dial_config.api_key.get_secret_value()}
if absolute_url.startswith(dial_config.dial_url)
else {}
)

file_name, content_type, original_doc_bytes = await load_attachment(
dial_api_client,
attachment_link,
headers,
download_config=config.download,
)
logger.debug(f"Successfully loaded document {file_name} of {content_type}")
Expand Down Expand Up @@ -235,10 +231,10 @@ async def load_document(
request_context: RequestContext,
task: IndexingTask,
index_storage: IndexStorage,
dial_api_client: DialApiClient,
config: RequestConfig,
) -> DocumentRecord:
attachment_link = task.attachment_link
dial_api_client = request_context.dial_api_client
with handle_document_processing_error(
attachment_link, config.log_document_links
):
Expand All @@ -247,7 +243,6 @@ async def load_document(

choice = request_context.choice

# TODO: Move check_document_access to the DialApiClient
await check_document_access(request_context, attachment_link, config)

doc_record = None
Expand All @@ -270,6 +265,7 @@ async def load_document(
io_stream = doc_stage.content_stream
try:
doc_record = await load_document_impl(
dial_api_client,
request_context.dial_config,
request_context.dial_limited_resources,
attachment_link,
Expand All @@ -295,12 +291,11 @@ async def load_document_task(
request_context: RequestContext,
task: IndexingTask,
index_storage: IndexStorage,
dial_api_client: DialApiClient,
config: RequestConfig,
) -> DocumentIndexingResult:
try:
doc_record = await load_document(
request_context, task, index_storage, dial_api_client, config
request_context, task, index_storage, config
)
return DocumentIndexingSuccess(
task=task,
Expand All @@ -318,16 +313,13 @@ async def load_documents(
request_context: RequestContext,
tasks: Iterable[IndexingTask],
index_storage: IndexStorage,
dial_api_client: DialApiClient,
config: RequestConfig,
) -> List[DocumentIndexingResult]:
# TODO: Rewrite this function using TaskGroup to cancel all tasks if one of them fails
# if ignore_document_loading_errors is not set in the config
return await asyncio.gather(
*[
load_document_task(
request_context, task, index_storage, dial_api_client, config
)
load_document_task(request_context, task, index_storage, config)
for task in tasks
],
)
Loading
Loading