From 06dc0b26a125f62dfd4e65a4737e85c259740202 Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Fri, 29 Aug 2025 11:32:54 -0500 Subject: [PATCH 1/9] init --- .gitignore | 1 + backend/chainlit/types/__init__.py | 1 + backend/chainlit/types/step.py | 402 +++++++++++++++++++++++++++++ backend/pyproject.toml | 3 +- backend/uv.lock | 17 +- 5 files changed, 422 insertions(+), 2 deletions(-) create mode 100644 backend/chainlit/types/__init__.py create mode 100644 backend/chainlit/types/step.py diff --git a/.gitignore b/.gitignore index f2272e91eb..288d0b25ee 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ build dist *.egg-info +*.ipynb .env diff --git a/backend/chainlit/types/__init__.py b/backend/chainlit/types/__init__.py new file mode 100644 index 0000000000..065afca0e1 --- /dev/null +++ b/backend/chainlit/types/__init__.py @@ -0,0 +1 @@ +from .step import Step, StepService \ No newline at end of file diff --git a/backend/chainlit/types/step.py b/backend/chainlit/types/step.py new file mode 100644 index 0000000000..c696a2a8aa --- /dev/null +++ b/backend/chainlit/types/step.py @@ -0,0 +1,402 @@ + +import asyncio +import inspect +import json +import uuid +from copy import deepcopy +from functools import wraps +from typing import Callable, Dict, List, Optional, TypedDict, Union, Literal + +from sqlmodel import SQLModel, Field + +# If you want to keep compatibility with literalai types, import as needed +from literalai import BaseGeneration +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel +from chainlit.config import config +from chainlit.context import CL_RUN_NAMES, context, local_steps +from chainlit.data import get_data_layer +from chainlit.element import Element +from chainlit.logger import logger +from chainlit.types import FeedbackDict +from chainlit.utils import utc_now + +TrueStepType = Literal[ + "run", "tool", "llm", "embedding", "retrieval", "rerank", "undefined" +] + +MessageStepType = Literal["user_message", "assistant_message", "system_message"] + +StepType = Union[TrueStepType, MessageStepType] + +class Step(SQLModel, table=True): + id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + name: str = "" + type: str = "undefined" + parent_id: Optional[str] = Field(default=None, foreign_key="step.id") + thread_id: Optional[str] = None + streaming: bool = False + persisted: bool = False + show_input: Union[bool, str] = "json" + is_error: Optional[bool] = False + metadata: Dict = Field(default_factory=dict) + tags: Optional[List[str]] = None + created_at: Optional[str] = None + start: Optional[str] = None + end: Optional[str] = None + generation: Optional[BaseGeneration] = None + language: Optional[str] = None + default_open: Optional[bool] = False + input: Optional[str] = "" + output: Optional[str] = "" + + # TODO define relationship with Element + # elements: List[Element] = Relationship(back_populates="step") + # thread: Optional[Thread] = Relationship(back_populates="steps") + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + +def flatten_args_kwargs(func, args, kwargs): + signature = inspect.signature(func) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return {k: deepcopy(v) for k, v in bound_arguments.arguments.items()} + +def check_add_step_in_cot(step: "Step"): + is_message = step.type in [ + "user_message", + "assistant_message", + ] + is_cl_run = step.name in CL_RUN_NAMES and step.type == "run" + if config.ui.cot == "hidden" and not is_message and not is_cl_run: + return False + return True + +# Step decorator for async and sync functions, now using StepService +def step( + original_function: Optional[Callable] = None, + *, + name: Optional[str] = "", + type: Optional[str] = "undefined", + id: Optional[str] = None, + parent_id: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict] = None, + language: Optional[str] = None, + show_input: Union[bool, str] = "json", + default_open: bool = False + ) -> Callable: + def wrapper(func: Callable): + nonlocal name + if not name: + name = func.__name__ + if inspect.iscoroutinefunction(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with StepService( + type=type, + name=name, + id=id, + parent_id=parent_id, + tags=tags, + language=language, + show_input=show_input, + default_open=default_open, + metadata=metadata, + ) as step: + try: + step.input = flatten_args_kwargs(func, args, kwargs) + except Exception: + pass + result = await func(*args, **kwargs) + try: + if result and not step.output: + step.output = result + except Exception: + step.is_error = True + step.output = str(result) + return result + return async_wrapper + else: + @wraps(func) + def sync_wrapper(*args, **kwargs): + with StepService( + type=type, + name=name, + id=id, + parent_id=parent_id, + tags=tags, + language=language, + show_input=show_input, + default_open=default_open, + metadata=metadata, + ) as step: + try: + step.input = flatten_args_kwargs(func, args, kwargs) + except Exception: + pass + result = func(*args, **kwargs) + try: + if result and not step.output: + step.output = result + except Exception: + step.is_error = True + step.output = str(result) + return result + return sync_wrapper + func = original_function + if not func: + return wrapper + else: + return wrapper(func) + + +# StepService: business logic, context managers, and decorator support +class StepService: + def __init__(self, **kwargs): + self.step = Step(**kwargs) + self.elements = [] + self.fail_on_persist_error = False + self._input = "" + self._output = "" + + @property + def input(self): + return self._input + + @input.setter + def input(self, content: Union[Dict, str]): + self._input = self._process_content(content, set_language=False) + self.step.input = self._input + + @property + def output(self): + return self._output + + @output.setter + def output(self, content: Union[Dict, str]): + self._output = self._process_content(content, set_language=True) + self.step.output = self._output + + def _clean_content(self, content): + def handle_bytes(item): + if isinstance(item, bytes): + return "STRIPPED_BINARY_DATA" + elif isinstance(item, dict): + return {k: handle_bytes(v) for k, v in item.items()} + elif isinstance(item, list): + return [handle_bytes(i) for i in item] + elif isinstance(item, tuple): + return tuple(handle_bytes(i) for i in item) + return item + return handle_bytes(content) + + def _process_content(self, content, set_language=False): + if content is None: + return "" + content = self._clean_content(content) + if isinstance(content, (dict, list, tuple)): + try: + processed_content = json.dumps(content, indent=4, ensure_ascii=False) + if set_language: + self.step.language = "json" + except TypeError: + processed_content = str(content).replace("\\n", "\n") + if set_language: + self.step.language = "text" + elif isinstance(content, str): + processed_content = content + else: + processed_content = str(content).replace("\\n", "\n") + if set_language: + self.step.language = "text" + return processed_content + + def to_dict(self): + return self.step.dict() + + # Context manager support + async def __aenter__(self): + self.start = utc_now() + previous_steps = local_steps.get() or [] + parent_step = previous_steps[-1] if previous_steps else None + + if not self.parent_id: + if parent_step: + self.parent_id = parent_step.id + local_steps.set(previous_steps + [self]) + await self.send() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.end = utc_now() + + if exc_type: + self.output = str(exc_val) + self.is_error = True + + current_steps = local_steps.get() + if current_steps and self in current_steps: + current_steps.remove(self) + local_steps.set(current_steps) + + await self.update() + + def __enter__(self): + self.start = utc_now() + + previous_steps = local_steps.get() or [] + parent_step = previous_steps[-1] if previous_steps else None + + if not self.parent_id: + if parent_step: + self.parent_id = parent_step.id + local_steps.set(previous_steps + [self]) + + asyncio.create_task(self.send()) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = utc_now() + + if exc_type: + self.output = str(exc_val) + self.is_error = True + + current_steps = local_steps.get() + if current_steps and self in current_steps: + current_steps.remove(self) + local_steps.set(current_steps) + + asyncio.create_task(self.update()) + + # Business logic methods restored from original Step class + + async def update(self): + """ + Update a step already sent to the UI. + """ + if self.step.streaming: + self.step.streaming = False + + step_dict = self.step.model_dump(by_alias=True) + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.update_step(step_dict.copy())) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step update: {e!s}") + + # elements logic + tasks = [el.send(for_id=self.step.id) for el in getattr(self, 'elements', [])] + await asyncio.gather(*tasks) + + # UI update logic + from chainlit.context import check_add_step_in_cot, stub_step + if not check_add_step_in_cot(self.step): + await context.emitter.update_step(stub_step(self.step)) + else: + await context.emitter.update_step(step_dict) + + return True + + + async def remove(self): + """ + Remove a step already sent to the UI. + """ + step_dict = self.to_dict() + from chainlit.data import get_data_layer + from chainlit.logger import logger + from chainlit.context import context + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.delete_step(self.step.id)) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step deletion: {e!s}") + + await context.emitter.delete_step(step_dict) + return True + + + async def send(self): + from chainlit.config import config + from chainlit.data import get_data_layer + from chainlit.logger import logger + from chainlit.context import context, check_add_step_in_cot, stub_step + if self.step.persisted: + return self + + if getattr(config.code, "author_rename", None): + self.step.name = await config.code.author_rename(self.step.name) + + if self.step.streaming: + self.step.streaming = False + + step_dict = self.to_dict() + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.create_step(step_dict.copy())) + self.step.persisted = True + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step creation: {e!s}") + + tasks = [el.send(for_id=self.step.id) for el in getattr(self, 'elements', [])] + await asyncio.gather(*tasks) + + if not check_add_step_in_cot(self.step): + await context.emitter.send_step(stub_step(self.step)) + else: + await context.emitter.send_step(step_dict) + + return self + + + async def stream_token(self, token: str, is_sequence=False, is_input=False): + """ + Sends a token to the UI. + Once all tokens have been streamed, call .send() to end the stream and persist the step if persistence is enabled. + """ + from chainlit.context import context, check_add_step_in_cot, stub_step + if not token: + return + + if is_sequence: + if is_input: + self.input = token + else: + self.output = token + else: + if is_input: + self.input += token + else: + self.output += token + + assert self.step.id + + if not check_add_step_in_cot(self.step): + await context.emitter.send_step(stub_step(self.step)) + return + + if not self.step.streaming: + self.step.streaming = True + step_dict = self.to_dict() + await context.emitter.stream_start(step_dict) + else: + await context.emitter.send_token( + id=self.step.id, token=token, is_sequence=is_sequence, is_input=is_input + ) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 08a8fb9093..91a8db4a55 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -55,7 +55,8 @@ dependencies = [ "python-multipart>=0.0.18,<1.0.0", "pyjwt>=2.8.0,<3.0.0", "audioop-lts>=0.2.1,<0.3.0; python_version>='3.13'", - "pydantic-settings>=2.10.1" + "pydantic-settings>=2.10.1", + "sqlmodel>=0.0.24" ] [project.urls] diff --git a/backend/uv.lock b/backend/uv.lock index 2983388a36..9ca675ba59 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -679,7 +679,7 @@ wheels = [ [[package]] name = "chainlit" -version = "2.7.1.1" +version = "2.7.2" source = { editable = "." } dependencies = [ { name = "aiofiles" }, @@ -701,6 +701,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "python-multipart" }, { name = "python-socketio" }, + { name = "sqlmodel" }, { name = "starlette" }, { name = "syncer" }, { name = "tomli" }, @@ -795,6 +796,7 @@ requires-dist = [ { name = "semantic-kernel", marker = "extra == 'tests'", specifier = ">=1.24.0,<2.0.0" }, { name = "slack-bolt", marker = "extra == 'tests'", specifier = ">=1.18.1,<2.0.0" }, { name = "sqlalchemy", marker = "extra == 'custom-data'", specifier = ">=2.0.28,<3.0.0" }, + { name = "sqlmodel", specifier = ">=0.0.24" }, { name = "starlette", specifier = ">=0.47.2" }, { name = "syncer", specifier = ">=2.0.3,<3.0.0" }, { name = "tenacity", marker = "extra == 'tests'", specifier = ">=8.4.1,<9.0.0" }, @@ -5314,6 +5316,19 @@ asyncio = [ { name = "greenlet" }, ] +[[package]] +name = "sqlmodel" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/4b/c2ad0496f5bdc6073d9b4cef52be9c04f2b37a5773441cc6600b1857648b/sqlmodel-0.0.24.tar.gz", hash = "sha256:cc5c7613c1a5533c9c7867e1aab2fd489a76c9e8a061984da11b4e613c182423", size = 116780, upload-time = "2025-03-07T05:43:32.887Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/91/484cd2d05569892b7fef7f5ceab3bc89fb0f8a8c0cde1030d383dbc5449c/sqlmodel-0.0.24-py3-none-any.whl", hash = "sha256:6778852f09370908985b667d6a3ab92910d0d5ec88adcaf23dbc242715ff7193", size = 28622, upload-time = "2025-03-07T05:43:30.37Z" }, +] + [[package]] name = "sse-starlette" version = "3.0.2" From c23dfd65c428b8111cc258a66ee1c8011aefa48f Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Fri, 29 Aug 2025 14:21:05 -0500 Subject: [PATCH 2/9] full mockup --- backend/alembic/README.md | 39 ++ backend/alembic/env.py | 52 +++ .../alembic/versions/0001_create_tables.py | 20 + backend/chainlit/cli/__init__.py | 31 ++ backend/chainlit/data/alembic/README.md | 39 ++ backend/chainlit/data/alembic/env.py | 52 +++ .../alembic/versions/0001_create_tables.py | 50 +++ backend/chainlit/data/sql_model.py | 373 ++++++++++++++++++ backend/chainlit/models/__init__.py | 5 + backend/chainlit/models/element.py | 296 ++++++++++++++ backend/chainlit/models/feedback.py | 31 ++ backend/chainlit/{types => models}/step.py | 321 +++++++-------- backend/chainlit/models/thread.py | 95 +++++ backend/chainlit/models/user.py | 35 ++ backend/chainlit/types/__init__.py | 1 - backend/pyproject.toml | 3 +- backend/uv.lock | 29 ++ 17 files changed, 1296 insertions(+), 176 deletions(-) create mode 100644 backend/alembic/README.md create mode 100644 backend/alembic/env.py create mode 100644 backend/alembic/versions/0001_create_tables.py create mode 100644 backend/chainlit/data/alembic/README.md create mode 100644 backend/chainlit/data/alembic/env.py create mode 100644 backend/chainlit/data/alembic/versions/0001_create_tables.py create mode 100644 backend/chainlit/data/sql_model.py create mode 100644 backend/chainlit/models/__init__.py create mode 100644 backend/chainlit/models/element.py create mode 100644 backend/chainlit/models/feedback.py rename backend/chainlit/{types => models}/step.py (75%) create mode 100644 backend/chainlit/models/thread.py create mode 100644 backend/chainlit/models/user.py delete mode 100644 backend/chainlit/types/__init__.py diff --git a/backend/alembic/README.md b/backend/alembic/README.md new file mode 100644 index 0000000000..5f504a1b51 --- /dev/null +++ b/backend/alembic/README.md @@ -0,0 +1,39 @@ +# Alembic Migrations for Chainlit SQLModelDataLayer + +This directory contains Alembic migration scripts for the SQLModel-based data layer. + +## Best Practices + +- **Do not use `SQLModel.metadata.create_all()` in production.** +- Always manage schema changes with Alembic migrations. +- Keep migration scripts in version control. +- Run migrations before starting the app, or enable auto-migration with `CHAINLIT_AUTO_MIGRATE=true`. + +## Usage + +1. **Configure your database URL** in `alembic.ini`: + ```ini + sqlalchemy.url = + ``` + +2. **Autogenerate a migration** (after changing models): + ```bash + alembic revision --autogenerate -m "Initial tables" + ``` + +3. **Apply migrations**: + ```bash + alembic upgrade head + ``` + +## Initial Migration + +The first migration should create all tables defined in `chainlit.models`. + +## env.py + +Alembic is configured to use `SQLModel.metadata` from `chainlit.models`. + +--- + +For more details, see the [Alembic documentation](https://alembic.sqlalchemy.org/en/latest/). diff --git a/backend/alembic/env.py b/backend/alembic/env.py new file mode 100644 index 0000000000..0435faa2af --- /dev/null +++ b/backend/alembic/env.py @@ -0,0 +1,52 @@ +import sys +import os +from logging.config import fileConfig +from sqlalchemy import engine_from_config, pool +from alembic import context + +# Add the parent directory to sys.path to import models +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from chainlit.models import SQLModel + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +target_metadata = SQLModel.metadata + +def run_migrations_offline(): + """ + Run migrations in 'offline' mode. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"} + ) + + with context.begin_transaction(): + context.run_migrations() + +def run_migrations_online(): + """ + Run migrations in 'online' mode. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/alembic/versions/0001_create_tables.py b/backend/alembic/versions/0001_create_tables.py new file mode 100644 index 0000000000..c04927f893 --- /dev/null +++ b/backend/alembic/versions/0001_create_tables.py @@ -0,0 +1,20 @@ +""" +Initial migration: create all tables for SQLModelDataLayer +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '0001_create_tables' +down_revision = None +branch_labels = None +depends_on = None + +def upgrade(): + # This will be auto-generated by Alembic's autogenerate feature, + # but for best practice, you can start with an empty upgrade and run autogenerate. + pass + +def downgrade(): + # Drop all tables (if needed) + pass diff --git a/backend/chainlit/cli/__init__.py b/backend/chainlit/cli/__init__.py index 5a03136e67..ed3b122c2b 100644 --- a/backend/chainlit/cli/__init__.py +++ b/backend/chainlit/cli/__init__.py @@ -47,6 +47,15 @@ def cli(): # Define the function to run Chainlit with provided options def run_chainlit(target: str): + import os + import subprocess + + def auto_run_alembic_upgrade(): + try: + subprocess.run(["alembic", "upgrade", "head"], check=True) + logger.info("Alembic migrations applied (upgrade to head).") + except Exception as e: + logger.error(f"Failed to run Alembic migrations: {e}") host = os.environ.get("CHAINLIT_HOST", DEFAULT_HOST) port = int(os.environ.get("CHAINLIT_PORT", DEFAULT_PORT)) root_path = os.environ.get("CHAINLIT_ROOT_PATH", DEFAULT_ROOT_PATH) @@ -76,6 +85,28 @@ def run_chainlit(target: str): config.run.module_name = target load_module(config.run.module_name) + # Check if SQLModelDataLayer is used and warn about Alembic migrations + data_layer_func = getattr(config.code, "data_layer", None) + if data_layer_func: + try: + dl_instance = data_layer_func() + from chainlit.data.sql_model import SQLModelDataLayer + if isinstance(dl_instance, SQLModelDataLayer): + # Get current version + try: + from chainlit.version import __version__ + except Exception: + __version__ = "unknown" + logger.info(f"SQLModelDataLayer detected. Chainlit version: {__version__}.") + auto_migrate = os.environ.get("CHAINLIT_AUTO_MIGRATE", "false").lower() in ["true", "1", "yes"] + if auto_migrate: + logger.info("Auto-migration enabled. Running Alembic migrations...") + auto_run_alembic_upgrade() + else: + logger.info("Auto-migration disabled. Run 'alembic upgrade head' after updating models or upgrading Chainlit.") + except Exception as e: + logger.warning(f"Could not check data layer type: {e}") + ensure_jwt_secret() assert_app() diff --git a/backend/chainlit/data/alembic/README.md b/backend/chainlit/data/alembic/README.md new file mode 100644 index 0000000000..5f504a1b51 --- /dev/null +++ b/backend/chainlit/data/alembic/README.md @@ -0,0 +1,39 @@ +# Alembic Migrations for Chainlit SQLModelDataLayer + +This directory contains Alembic migration scripts for the SQLModel-based data layer. + +## Best Practices + +- **Do not use `SQLModel.metadata.create_all()` in production.** +- Always manage schema changes with Alembic migrations. +- Keep migration scripts in version control. +- Run migrations before starting the app, or enable auto-migration with `CHAINLIT_AUTO_MIGRATE=true`. + +## Usage + +1. **Configure your database URL** in `alembic.ini`: + ```ini + sqlalchemy.url = + ``` + +2. **Autogenerate a migration** (after changing models): + ```bash + alembic revision --autogenerate -m "Initial tables" + ``` + +3. **Apply migrations**: + ```bash + alembic upgrade head + ``` + +## Initial Migration + +The first migration should create all tables defined in `chainlit.models`. + +## env.py + +Alembic is configured to use `SQLModel.metadata` from `chainlit.models`. + +--- + +For more details, see the [Alembic documentation](https://alembic.sqlalchemy.org/en/latest/). diff --git a/backend/chainlit/data/alembic/env.py b/backend/chainlit/data/alembic/env.py new file mode 100644 index 0000000000..0435faa2af --- /dev/null +++ b/backend/chainlit/data/alembic/env.py @@ -0,0 +1,52 @@ +import sys +import os +from logging.config import fileConfig +from sqlalchemy import engine_from_config, pool +from alembic import context + +# Add the parent directory to sys.path to import models +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from chainlit.models import SQLModel + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +target_metadata = SQLModel.metadata + +def run_migrations_offline(): + """ + Run migrations in 'offline' mode. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"} + ) + + with context.begin_transaction(): + context.run_migrations() + +def run_migrations_online(): + """ + Run migrations in 'online' mode. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/chainlit/data/alembic/versions/0001_create_tables.py b/backend/chainlit/data/alembic/versions/0001_create_tables.py new file mode 100644 index 0000000000..ca1067b11c --- /dev/null +++ b/backend/chainlit/data/alembic/versions/0001_create_tables.py @@ -0,0 +1,50 @@ +""" +Initial migration: migrate camelCase columns to snake_case for SQLModelDataLayer +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '0001_create_tables' +down_revision = None +branch_labels = None +depends_on = None + +def upgrade(): + # Thread table: rename camelCase columns to snake_case + with op.batch_alter_table('thread') as batch_op: + batch_op.rename_column('createdAt', 'created_at') + batch_op.rename_column('userId', 'user_id') + batch_op.rename_column('userIdentifier', 'user_identifier') + # tags and metadata are already snake_case or compatible + # Repeat for other tables (User, Step, Element, Feedback) + with op.batch_alter_table('user') as batch_op: + batch_op.rename_column('createdAt', 'created_at') + with op.batch_alter_table('step') as batch_op: + batch_op.rename_column('threadId', 'thread_id') + batch_op.rename_column('parentId', 'parent_id') + batch_op.rename_column('createdAt', 'created_at') + with op.batch_alter_table('element') as batch_op: + batch_op.rename_column('threadId', 'thread_id') + batch_op.rename_column('objectKey', 'object_key') + with op.batch_alter_table('feedback') as batch_op: + batch_op.rename_column('forId', 'for_id') + # If tables do not exist, Alembic will error; users should run this only once during migration. + +def downgrade(): + # Reverse the renames for downgrade + with op.batch_alter_table('thread') as batch_op: + batch_op.rename_column('created_at', 'createdAt') + batch_op.rename_column('user_id', 'userId') + batch_op.rename_column('user_identifier', 'userIdentifier') + with op.batch_alter_table('user') as batch_op: + batch_op.rename_column('created_at', 'createdAt') + with op.batch_alter_table('step') as batch_op: + batch_op.rename_column('thread_id', 'threadId') + batch_op.rename_column('parent_id', 'parentId') + batch_op.rename_column('created_at', 'createdAt') + with op.batch_alter_table('element') as batch_op: + batch_op.rename_column('thread_id', 'threadId') + batch_op.rename_column('object_key', 'objectKey') + with op.batch_alter_table('feedback') as batch_op: + batch_op.rename_column('for_id', 'forId') diff --git a/backend/chainlit/data/sql_model.py b/backend/chainlit/data/sql_model.py new file mode 100644 index 0000000000..978ad3c1df --- /dev/null +++ b/backend/chainlit/data/sql_model.py @@ -0,0 +1,373 @@ +from sqlmodel import SQLModel, create_engine, select +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine, async_sessionmaker +from contextlib import asynccontextmanager +from chainlit.data.base import BaseDataLayer +from chainlit.data.storage_clients.base import BaseStorageClient +from chainlit.data.utils import queue_until_user_message +from typing import Optional, Any, Dict +from chainlit.models import PersistedUser, User, Feedback, Thread, Element, Step +import json +import ssl +import uuid +from chainlit.logger import logger +from chainlit.types import ( + PaginatedResponse, + Pagination, + ThreadFilter, + PageInfo +) + +class SQLModelDataLayer(BaseDataLayer): + def __init__( + self, + conninfo: str, + connect_args: Optional[dict[str, Any]] = None, + ssl_require: bool = False, + storage_provider: Optional[BaseStorageClient] = None, + user_thread_limit: Optional[int] = 1000, + show_logger: Optional[bool] = False, + ): + self._conninfo = conninfo + self.user_thread_limit = user_thread_limit + self.show_logger = show_logger + if connect_args is None: + connect_args = {} + if ssl_require: + # Create an SSL context to require an SSL connection + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + connect_args["ssl"] = ssl_context + self.engine = AsyncEngine( + create_async_engine( + self._conninfo, + connect_args=connect_args, + echo=self.show_logger, + ) + ) + self.async_session = async_sessionmaker( + bind=self.engine, expire_on_commit=False, class_=AsyncSession + ) + if storage_provider: + self.storage_provider: Optional[BaseStorageClient] = storage_provider + if self.show_logger: + logger.info("SQLModel storage client initialized") + else: + self.storage_provider = None + logger.warning( + "SQLModel storage client is not initialized and elements will not be persisted!" + ) + + async def init_db(self): + """ + Explicitly create tables for development or testing only. + In production, use Alembic migrations! + """ + logger.warning("init_db should only be used for local development or tests. Use Alembic for production migrations.") + async with self.engine.begin() as conn: + # await conn.run_sync(SQLModel.metadata.drop_all) # Uncomment to drop tables + await conn.run_sync(SQLModel.metadata.create_all) + + async def create_user(self, user: User) -> Optional[PersistedUser]: + async with self.async_session() as session: + result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == user.identifier)) + existing = result.scalar_one_or_none() + if not existing: + db_user = PersistedUser(identifier=user.identifier, metadata=user.metadata) + session.add(db_user) + await session.commit() + await session.refresh(db_user) + return PersistedUser(identifier=db_user.identifier, metadata=db_user.metadata) + return PersistedUser(identifier=existing.identifier, metadata=existing.metadata) + + async def get_user(self, identifier: str) -> Optional[PersistedUser]: + async with self.async_session() as session: + result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier)) + user = result.scalar_one_or_none() + if user: + return PersistedUser(identifier=user.identifier, metadata=user.metadata) + return None + + async def update_user(self, identifier: str, metadata: Optional[dict] = None) -> Optional[PersistedUser]: + async with self.async_session() as session: + result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier)) + user = result.scalar_one_or_none() + if user: + if metadata is not None: + user.metadata = metadata + await session.commit() + await session.refresh(user) + return PersistedUser(identifier=user.identifier, metadata=user.metadata) + return None + + async def delete_user(self, identifier: str) -> bool: + async with self.async_session() as session: + result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier)) + user = result.scalar_one_or_none() + if user: + await session.delete(user) + await session.commit() + return True + return False + + async def create_thread(self, thread_data: dict) -> Optional[Dict]: + async with self.async_session() as session: + db_thread = Thread.model_validate(thread_data) + session.add(db_thread) + await session.commit() + await session.refresh(db_thread) + return db_thread.to_dict() + + async def get_thread(self, thread_id: str) -> Optional[Dict]: + async with self.async_session() as session: + result = await session.execute(select(Thread).where(Thread.id == thread_id)) + thread = result.scalar_one_or_none() + if thread: + return thread.to_dict() + return None + + async def get_thread_author(self, thread_id: str) -> str: + async with self.async_session() as session: + result = await session.execute(select(Thread).where(Thread.id == thread_id)) + thread: Thread = result.scalar_one_or_none() + if thread and thread.user_identifier: + return thread.user_identifier + return "" + + async def update_thread(self, thread_id: str, **kwargs) -> Optional[Dict]: + async with self.async_session() as session: + result = await session.execute(select(Thread).where(Thread.id == thread_id)) + thread = result.scalar_one_or_none() + if thread: + for k, v in kwargs.items(): + setattr(thread, k, v) + await session.commit() + await session.refresh(thread) + return thread.to_dict() + return None + + async def delete_thread(self, thread_id: str) -> bool: + async with self.async_session() as session: + result = await session.execute(select(Thread).where(Thread.id == thread_id)) + thread = result.scalar_one_or_none() + if thread: + await session.delete(thread) + await session.commit() + return True + return False + + @queue_until_user_message() + async def create_step(self, step_data: dict) -> Optional[Dict]: + async with self.async_session() as session: + db_step = Step.model_validate(step_data) + session.add(db_step) + await session.commit() + await session.refresh(db_step) + return db_step.to_dict() + + async def get_step(self, step_id: str) -> Optional[Dict]: + async with self.async_session() as session: + result = await session.execute(select(Step).where(Step.id == step_id)) + step = result.scalar_one_or_none() + if step: + return step.to_dict() + return None + + @queue_until_user_message() + async def update_step(self, step_id: str, **kwargs) -> Optional[Dict]: + async with self.async_session() as session: + result = await session.execute(select(Step).where(Step.id == step_id)) + step = result.scalar_one_or_none() + if step: + for k, v in kwargs.items(): + setattr(step, k, v) + await session.commit() + await session.refresh(step) + return step.to_dict() + return None + + @queue_until_user_message() + async def delete_step(self, step_id: str) -> bool: + async with self.async_session() as session: + result = await session.execute(select(Step).where(Step.id == step_id)) + step = result.scalar_one_or_none() + if step: + await session.delete(step) + await session.commit() + return True + return False + + async def upsert_feedback(self, feedback: Feedback) -> str: + feedback_id = feedback.id or str(uuid.uuid4()) + feedback_dict = feedback.dict() + feedback_dict["id"] = feedback_id + async with self.async_session() as session: + result = await session.execute(select(Feedback).where(Feedback.id == feedback_id)) + db_feedback = result.scalar_one_or_none() + if db_feedback: + for k, v in feedback_dict.items(): + setattr(db_feedback, k, v) + else: + db_feedback = Feedback.model_validate(feedback_dict) + session.add(db_feedback) + await session.commit() + await session.refresh(db_feedback) + return db_feedback.id + + async def get_feedback(self, feedback_id: str) -> Optional[Dict]: + async with self.async_session() as session: + result = await session.execute(select(Feedback).where(Feedback.id == feedback_id)) + feedback = result.scalar_one_or_none() + if feedback: + return feedback.to_dict() + return None + + async def delete_feedback(self, feedback_id: str) -> bool: + async with self.async_session() as session: + result = await session.execute(select(Feedback).where(Feedback.id == feedback_id)) + feedback = result.scalar_one_or_none() + if feedback: + await session.delete(feedback) + await session.commit() + return True + return False + + async def get_element(self, thread_id: str, element_id: str) -> Optional[Dict]: + async with self.async_session() as session: + result = await session.execute( + select(Element).where(Element.thread_id == thread_id, Element.id == element_id) + ) + element = result.scalar_one_or_none() + if element: + # props should be deserialized if stored as JSON string + props = element.props + if isinstance(props, str): + props = json.loads(props) + return { + **element.to_dict(), + "props": props, + } + return None + + @queue_until_user_message() + async def create_element(self, element: "Element"): + if self.show_logger: + logger.info(f"SQLModel: create_element, element_id = {element.id}") + + if not self.storage_provider: + logger.warning("SQLModel: create_element error. No blob_storage_client is configured!") + return + if not element.for_id: + return + + content: Optional[bytes] = None + if element.path: + import aiofiles + async with aiofiles.open(element.path, "rb") as f: + content = await f.read() + elif element.url: + import aiohttp + async with aiohttp.ClientSession() as session_http: + async with session_http.get(element.url) as response: + if response.status == 200: + content = await response.read() + else: + content = None + elif element.content: + content = element.content + else: + raise ValueError("Element url, path or content must be provided") + if content is None: + raise ValueError("Content is None, cannot upload file") + + user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown" + file_object_key = f"{user_id}/{element.id}" + (f"/{element.name}" if element.name else "") + + if not element.mime: + element.mime = "application/octet-stream" + + uploaded_file = await self.storage_provider.upload_file( + object_key=file_object_key, data=content, mime=element.mime, overwrite=True + ) + if not uploaded_file: + raise ValueError("SQLModel Error: create_element, Failed to persist data in storage_provider") + + element_dict = element.to_dict() + element_dict["url"] = uploaded_file.get("url") + element_dict["objectKey"] = uploaded_file.get("object_key") + element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None} + if "props" in element_dict_cleaned: + element_dict_cleaned["props"] = json.dumps(element_dict_cleaned["props"]) + + async with self.async_session() as session: + db_element = Element.model_validate(element_dict_cleaned) + session.add(db_element) + await session.commit() + await session.refresh(db_element) + return db_element.to_dict() + + @queue_until_user_message() + async def delete_element(self, element_id: str, thread_id: Optional[str] = None): + if self.show_logger: + logger.info(f"SQLModel: delete_element, element_id={element_id}") + + async with self.async_session() as session: + query = select(Element).where(Element.id == element_id) + if thread_id: + query = query.where(Element.thread_id == thread_id) + result = await session.execute(query) + element = result.scalar_one_or_none() + element_dict = element.to_dict() if element else None + if ( + self.storage_provider is not None + and element is not None + and getattr(element_dict, "objectKey", None) + ): + await self.storage_provider.delete_file(object_key=element['objectKey']) + if element: + await session.delete(element) + await session.commit() + + async def build_debug_url(self) -> str: + # Implement as needed, or return empty string for now + return "" + + async def list_threads( + self, pagination: Pagination, filters: ThreadFilter + ) -> PaginatedResponse[Dict]: + # Fetch threads for a user, apply pagination and filters + async with self.async_session() as session: + if filters.userId: + query = select(Thread).where(Thread.user_id == filters.userId) + result = await session.execute(query) + threads = result.scalars().all() + # Apply search filter + if filters.search: + threads = [t for t in threads if filters.search.lower() in (t.name or '').lower()] + # Apply feedback filter (if present) + if filters.feedback is not None: + # This requires joining with Feedback, so for now, skip or implement as needed + pass + # Pagination + start = 0 + if pagination.cursor: + for i, t in enumerate(threads): + if t.id == pagination.cursor: + start = i + 1 + break + end = start + pagination.first + paginated_threads = threads[start:end] + has_next_page = len(threads) > end + start_cursor = paginated_threads[0].id if paginated_threads else None + end_cursor = paginated_threads[-1].id if paginated_threads else None + # Convert to dicts + data = [t.to_dict() for t in paginated_threads] + # Build PaginatedResponse + return PaginatedResponse( + pageInfo=PageInfo( + hasNextPage=has_next_page, + startCursor=start_cursor, + endCursor=end_cursor, + ), + data=data, + ) \ No newline at end of file diff --git a/backend/chainlit/models/__init__.py b/backend/chainlit/models/__init__.py new file mode 100644 index 0000000000..e2b017df11 --- /dev/null +++ b/backend/chainlit/models/__init__.py @@ -0,0 +1,5 @@ +from .step import Step as Step +from .user import User, PersistedUser +from .thread import Thread +from .feedback import Feedback, UpdateFeedbackRequest, DeleteFeedbackRequest +from .element import Element, Image, Text, Audio, Video, File, Pyplot, Plotly, CustomElement, Pdf, TaskList, Dataframe \ No newline at end of file diff --git a/backend/chainlit/models/element.py b/backend/chainlit/models/element.py new file mode 100644 index 0000000000..0d31ae271b --- /dev/null +++ b/backend/chainlit/models/element.py @@ -0,0 +1,296 @@ +from typing import Optional, Dict, List, Literal, Union, ClassVar, TypeVar, Any, cast +from sqlmodel import SQLModel, Field +import uuid +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel +from syncer import asyncio +import filetype +from chainlit.context import context +from chainlit.data import get_data_layer +from chainlit.logger import logger +from chainlit.element import Task, TaskStatus +import json + +mime_types = { + "text": "text/plain", + "tasklist": "application/json", + "plotly": "application/json", +} + +ElementType = Literal[ + "image", + "text", + "pdf", + "tasklist", + "audio", + "video", + "file", + "plotly", + "dataframe", + "custom", +] +ElementDisplay = Literal["inline", "side", "page"] +ElementSize = Literal["small", "medium", "large"] + +class Element(SQLModel, table=True): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + thread_id: Optional[str] = None + type: ElementType + name: str = "" + url: Optional[str] = None + path: Optional[str] = None + object_key: Optional[str] = None + chainlit_key: Optional[str] = None + display: ElementDisplay + size: Optional[ElementSize] = None + language: Optional[str] = None + mime: Optional[str] = None + for_id: Optional[str] = None + # Add other common fields as needed + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + def to_dict(self): + return self.model_dump(by_alias=True) + + @classmethod + def from_dict(cls, **kwargs): + type_ = kwargs.get("type", "file") + if type_ == "image": + return Image(**kwargs) + elif type_ == "audio": + return Audio(**kwargs) + elif type_ == "video": + return Video(**kwargs) + elif type_ == "plotly": + return Plotly(**kwargs) + elif type_ == "custom": + return CustomElement(**kwargs) + elif type_ == "pdf": + return Pdf(**kwargs) + elif type_ == "tasklist": + return TaskList(**kwargs) + elif type_ == "dataframe": + return Dataframe(**kwargs) + elif type_ == "text": + return Text(**kwargs) + else: + return File(**kwargs) + + @classmethod + def infer_type_from_mime(cls, mime_type: str): + """Infer the element type from a mime type. Useful to know which element to instantiate from a file upload.""" + if "image" in mime_type: + return "image" + + elif mime_type == "application/pdf": + return "pdf" + + elif "audio" in mime_type: + return "audio" + + elif "video" in mime_type: + return "video" + + else: + return "file" + + async def _create(self, persist=True) -> bool: + if getattr(self, "persisted", False) and not getattr(self, "updatable", False): + return True + + data_layer = get_data_layer() + if data_layer and persist: + try: + import asyncio + task = asyncio.create_task(data_layer.create_element(self)) + except Exception as e: + logger.error(f"Failed to create element: {e!s}") + + if not self.url and (not self.chainlit_key or getattr(self, "updatable", False)): + file_dict = await context.session.persist_file( + name=self.name, + path=self.path, + content=self.content, + mime=self.mime or "", + ) + self.chainlit_key = file_dict["id"] + + self.persisted = True + return True + + async def remove(self): + data_layer = get_data_layer() + if data_layer: + await data_layer.delete_element(self.id, self.thread_id) + await context.emitter.emit("remove_element", {"id": self.id}) + + async def send(self, for_id: str, persist=True): + self.for_id = for_id + + if not self.mime: + if hasattr(self, "type") and self.type in mime_types: + self.mime = mime_types[self.type] + elif self.path or isinstance(self.content, (bytes, bytearray)): + import filetype + file_type = filetype.guess(self.path or self.content) + if file_type: + self.mime = file_type.mime + elif self.url: + import mimetypes + self.mime = mimetypes.guess_type(self.url)[0] + + await self._create(persist=persist) + + if not self.url and not self.chainlit_key: + raise ValueError("Must provide url or chainlit key to send element") + + await context.emitter.send_element(self.to_dict()) + +ElementBased = TypeVar("ElementBased", bound=Element) + +# Subclasses for runtime logic (not persisted, but can be instantiated from Element) +class Image(Element): + type: ClassVar[ElementType] = "image" + size: Optional[str] = "medium" + +class Text(Element): + type: ElementType = "text" + language: Optional[str] = None + +class Pdf(Element): + type: ElementType = "pdf" + mime: str = "application/pdf" + page: Optional[int] = None + +class Pyplot(Element): + """Useful to send a pyplot to the UI.""" + + # We reuse the frontend image element to display the chart + type: ClassVar[ElementType] = "image" + + size: ElementSize = "medium" + # The type is set to Any because the figure is not serializable + # and its actual type is checked in __post_init__. + figure: Any = None + content: bytes = b"" + + def __post_init__(self) -> None: + from matplotlib.figure import Figure + from io import BytesIO + + if not isinstance(self.figure, Figure): + raise TypeError("figure must be a matplotlib.figure.Figure") + + image = BytesIO() + self.figure.savefig( + image, dpi=200, bbox_inches="tight", backend="Agg", format="png" + ) + self.content = image.getvalue() + + super().__post_init__() + +class TaskList(Element): + type: ElementType = "tasklist" + tasks: List = Field(default_factory=list) + status: str = "Ready" + name: str = "tasklist" + content: str = "dummy content to pass validation" + + def __post_init__(self) -> None: + super().__post_init__() + self.updatable = True + + async def add_task(self, task: Task): + self.tasks.append(task) + + async def update(self): + await self.send() + + async def send(self): + await self.preprocess_content() + await super().send(for_id="") + + async def preprocess_content(self): + # serialize enum + tasks = [ + {"title": task.title, "status": task.status.value, "forId": task.forId} + for task in self.tasks + ] + + # store stringified json in content so that it's correctly stored in the database + self.content = json.dumps( + { + "status": self.status, + "tasks": tasks, + }, + indent=4, + ensure_ascii=False, + ) + +class Audio(Element): + type: ClassVar[ElementType] = "audio" + auto_play: bool = False + +class Video(Element): + type: ClassVar[ElementType] = "video" + size: ElementSize = "medium" + # Override settings for each type of player in ReactPlayer + # https://github.com/cookpete/react-player?tab=readme-ov-file#config-prop + player_config: Optional[dict] = None + +class File(Element): + type: ElementType = "file" + +class Plotly(Element): + type: ElementType = "plotly" + size: Optional[str] = "medium" + figure: Optional[Any] = None + content: str = "" + + def __post_init__(self) -> None: + from plotly import graph_objects as go, io as pio + + if not isinstance(self.figure, go.Figure): + raise TypeError("figure must be a plotly.graph_objects.Figure") + + self.figure.layout.autosize = True + self.figure.layout.width = None + self.figure.layout.height = None + self.content = pio.to_json(self.figure, validate=True) + self.mime = "application/json" + + super().__post_init__() + +class Dataframe(Element): + type: ElementType = "dataframe" + size: Optional[str] = "large" + data: Any = None # The type is Any because it is checked in __post_init__. + + def __post_init__(self) -> None: + """Ensures the data is a pandas DataFrame and converts it to JSON.""" + from pandas import DataFrame + + if not isinstance(self.data, DataFrame): + raise TypeError("data must be a pandas.DataFrame") + + self.content = self.data.to_json(orient="split", date_format="iso") + super().__post_init__() + +class CustomElement(Element): + """Useful to send a custom element to the UI.""" + + type: ClassVar[ElementType] = "custom" + mime: str = "application/json" + props: Dict = Field(default_factory=dict) + + def __post_init__(self) -> None: + self.content = json.dumps(self.props) + super().__post_init__() + self.updatable = True + + async def update(self): + await super().send(self.for_id) \ No newline at end of file diff --git a/backend/chainlit/models/feedback.py b/backend/chainlit/models/feedback.py new file mode 100644 index 0000000000..eee883084d --- /dev/null +++ b/backend/chainlit/models/feedback.py @@ -0,0 +1,31 @@ +from typing import Dict, Optional, Literal +from sqlmodel import SQLModel, Field +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel + +FeedbackStrategy = Literal["BINARY"] + +class Feedback(SQLModel, table=True): + id: Optional[str] = Field(default=None, primary_key=True) + for_id: str + value: Literal[0, 1] + thread_id: Optional[str] = None + comment: Optional[str] = None + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + def to_dict(self): + data = self.model_dump(by_alias=True) + data.pop("threadId", None) + return data + +class UpdateFeedbackRequest(BaseModel): + feedback: Feedback + session_id: str + +class DeleteFeedbackRequest(BaseModel): + feedbackId: str diff --git a/backend/chainlit/types/step.py b/backend/chainlit/models/step.py similarity index 75% rename from backend/chainlit/types/step.py rename to backend/chainlit/models/step.py index c696a2a8aa..e064c83ce7 100644 --- a/backend/chainlit/types/step.py +++ b/backend/chainlit/models/step.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional, TypedDict, Union, Literal from sqlmodel import SQLModel, Field - +from pydantic import PrivateAttr # If you want to keep compatibility with literalai types, import as needed from literalai import BaseGeneration from pydantic import ConfigDict @@ -29,6 +29,7 @@ StepType = Union[TrueStepType, MessageStepType] + class Step(SQLModel, table=True): id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) name: str = "" @@ -59,127 +60,29 @@ class Step(SQLModel, table=True): populate_by_name=True, ) -def flatten_args_kwargs(func, args, kwargs): - signature = inspect.signature(func) - bound_arguments = signature.bind(*args, **kwargs) - bound_arguments.apply_defaults() - return {k: deepcopy(v) for k, v in bound_arguments.arguments.items()} - -def check_add_step_in_cot(step: "Step"): - is_message = step.type in [ - "user_message", - "assistant_message", - ] - is_cl_run = step.name in CL_RUN_NAMES and step.type == "run" - if config.ui.cot == "hidden" and not is_message and not is_cl_run: - return False - return True - -# Step decorator for async and sync functions, now using StepService -def step( - original_function: Optional[Callable] = None, - *, - name: Optional[str] = "", - type: Optional[str] = "undefined", - id: Optional[str] = None, - parent_id: Optional[str] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - language: Optional[str] = None, - show_input: Union[bool, str] = "json", - default_open: bool = False - ) -> Callable: - def wrapper(func: Callable): - nonlocal name - if not name: - name = func.__name__ - if inspect.iscoroutinefunction(func): - @wraps(func) - async def async_wrapper(*args, **kwargs): - async with StepService( - type=type, - name=name, - id=id, - parent_id=parent_id, - tags=tags, - language=language, - show_input=show_input, - default_open=default_open, - metadata=metadata, - ) as step: - try: - step.input = flatten_args_kwargs(func, args, kwargs) - except Exception: - pass - result = await func(*args, **kwargs) - try: - if result and not step.output: - step.output = result - except Exception: - step.is_error = True - step.output = str(result) - return result - return async_wrapper - else: - @wraps(func) - def sync_wrapper(*args, **kwargs): - with StepService( - type=type, - name=name, - id=id, - parent_id=parent_id, - tags=tags, - language=language, - show_input=show_input, - default_open=default_open, - metadata=metadata, - ) as step: - try: - step.input = flatten_args_kwargs(func, args, kwargs) - except Exception: - pass - result = func(*args, **kwargs) - try: - if result and not step.output: - step.output = result - except Exception: - step.is_error = True - step.output = str(result) - return result - return sync_wrapper - func = original_function - if not func: - return wrapper - else: - return wrapper(func) - - -# StepService: business logic, context managers, and decorator support -class StepService: - def __init__(self, **kwargs): - self.step = Step(**kwargs) - self.elements = [] - self.fail_on_persist_error = False - self._input = "" - self._output = "" + # Private attributes for business logic (not persisted or serialized) + _elements: Optional[List[Element]] = PrivateAttr(default_factory=list) + _fail_on_persist_error: bool = PrivateAttr(default=False) + _input: str = PrivateAttr(default="") + _output: str = PrivateAttr(default="") @property - def input(self): + def input_value(self): return self._input - @input.setter - def input(self, content: Union[Dict, str]): + @input_value.setter + def input_value(self, content: Union[Dict, str]): self._input = self._process_content(content, set_language=False) - self.step.input = self._input + self.input = self._input @property - def output(self): + def output_value(self): return self._output - @output.setter - def output(self, content: Union[Dict, str]): + @output_value.setter + def output_value(self, content: Union[Dict, str]): self._output = self._process_content(content, set_language=True) - self.step.output = self._output + self.output = self._output def _clean_content(self, content): def handle_bytes(item): @@ -202,21 +105,21 @@ def _process_content(self, content, set_language=False): try: processed_content = json.dumps(content, indent=4, ensure_ascii=False) if set_language: - self.step.language = "json" + self.language = "json" except TypeError: processed_content = str(content).replace("\\n", "\n") if set_language: - self.step.language = "text" + self.language = "text" elif isinstance(content, str): processed_content = content else: processed_content = str(content).replace("\\n", "\n") if set_language: - self.step.language = "text" + self.language = "text" return processed_content def to_dict(self): - return self.step.dict() + return self.model_dump(by_alias=True) # Context manager support async def __aenter__(self): @@ -224,9 +127,8 @@ async def __aenter__(self): previous_steps = local_steps.get() or [] parent_step = previous_steps[-1] if previous_steps else None - if not self.parent_id: - if parent_step: - self.parent_id = parent_step.id + if not self.parent_id and parent_step: + self.parent_id = parent_step.id local_steps.set(previous_steps + [self]) await self.send() return self @@ -235,7 +137,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.end = utc_now() if exc_type: - self.output = str(exc_val) + self.output_value = str(exc_val) self.is_error = True current_steps = local_steps.get() @@ -256,14 +158,14 @@ def __enter__(self): self.parent_id = parent_step.id local_steps.set(previous_steps + [self]) - asyncio.create_task(self.send()) + task = asyncio.create_task(self.send()) return self def __exit__(self, exc_type, exc_val, exc_tb): self.end = utc_now() if exc_type: - self.output = str(exc_val) + self.output_value = str(exc_val) self.is_error = True current_steps = local_steps.get() @@ -271,55 +173,41 @@ def __exit__(self, exc_type, exc_val, exc_tb): current_steps.remove(self) local_steps.set(current_steps) - asyncio.create_task(self.update()) - - # Business logic methods restored from original Step class + task = asyncio.create_task(self.update()) async def update(self): - """ - Update a step already sent to the UI. - """ - if self.step.streaming: - self.step.streaming = False + if self.streaming: + self.streaming = False - step_dict = self.step.model_dump(by_alias=True) + step_dict = self.to_dict() data_layer = get_data_layer() if data_layer: try: - asyncio.create_task(data_layer.update_step(step_dict.copy())) + task = asyncio.create_task(data_layer.update_step(step_dict.copy())) except Exception as e: if self.fail_on_persist_error: raise e logger.error(f"Failed to persist step update: {e!s}") - # elements logic - tasks = [el.send(for_id=self.step.id) for el in getattr(self, 'elements', [])] + tasks = [el.send(for_id=self.id) for el in getattr(self, 'elements', [])] await asyncio.gather(*tasks) - # UI update logic from chainlit.context import check_add_step_in_cot, stub_step - if not check_add_step_in_cot(self.step): - await context.emitter.update_step(stub_step(self.step)) + if not check_add_step_in_cot(self): + await context.emitter.update_step(stub_step(self)) else: await context.emitter.update_step(step_dict) return True - async def remove(self): - """ - Remove a step already sent to the UI. - """ step_dict = self.to_dict() - from chainlit.data import get_data_layer - from chainlit.logger import logger - from chainlit.context import context data_layer = get_data_layer() if data_layer: try: - asyncio.create_task(data_layer.delete_step(self.step.id)) + task = asyncio.create_task(data_layer.delete_step(self.id)) except Exception as e: if self.fail_on_persist_error: raise e @@ -328,75 +216,160 @@ async def remove(self): await context.emitter.delete_step(step_dict) return True - async def send(self): - from chainlit.config import config - from chainlit.data import get_data_layer - from chainlit.logger import logger - from chainlit.context import context, check_add_step_in_cot, stub_step - if self.step.persisted: + if self.persisted: return self if getattr(config.code, "author_rename", None): - self.step.name = await config.code.author_rename(self.step.name) + self.name = await config.code.author_rename(self.name) - if self.step.streaming: - self.step.streaming = False + if self.streaming: + self.streaming = False step_dict = self.to_dict() data_layer = get_data_layer() if data_layer: try: - asyncio.create_task(data_layer.create_step(step_dict.copy())) - self.step.persisted = True + task = asyncio.create_task(data_layer.create_step(step_dict.copy())) + self.persisted = True except Exception as e: if self.fail_on_persist_error: raise e logger.error(f"Failed to persist step creation: {e!s}") - tasks = [el.send(for_id=self.step.id) for el in getattr(self, 'elements', [])] + tasks = [el.send(for_id=self.id) for el in getattr(self, 'elements', [])] await asyncio.gather(*tasks) - if not check_add_step_in_cot(self.step): - await context.emitter.send_step(stub_step(self.step)) + from chainlit.context import check_add_step_in_cot, stub_step + if not check_add_step_in_cot(self): + await context.emitter.send_step(stub_step(self)) else: await context.emitter.send_step(step_dict) return self - async def stream_token(self, token: str, is_sequence=False, is_input=False): - """ - Sends a token to the UI. - Once all tokens have been streamed, call .send() to end the stream and persist the step if persistence is enabled. - """ from chainlit.context import context, check_add_step_in_cot, stub_step if not token: return if is_sequence: if is_input: - self.input = token + self.input_value = token else: - self.output = token + self.output_value = token else: if is_input: - self.input += token + self.input_value += token else: - self.output += token + self.output_value += token - assert self.step.id + assert self.id - if not check_add_step_in_cot(self.step): - await context.emitter.send_step(stub_step(self.step)) + if not check_add_step_in_cot(self): + await context.emitter.send_step(stub_step(self)) return - if not self.step.streaming: - self.step.streaming = True + if not self.streaming: + self.streaming = True step_dict = self.to_dict() await context.emitter.stream_start(step_dict) else: await context.emitter.send_token( - id=self.step.id, token=token, is_sequence=is_sequence, is_input=is_input + id=self.id, token=token, is_sequence=is_sequence, is_input=is_input ) + +def flatten_args_kwargs(func, args, kwargs): + signature = inspect.signature(func) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return {k: deepcopy(v) for k, v in bound_arguments.arguments.items()} + +def check_add_step_in_cot(step: Step): + is_message = step.type in [ + "user_message", + "assistant_message", + ] + is_cl_run = step.name in CL_RUN_NAMES and step.type == "run" + if config.ui.cot == "hidden" and not is_message and not is_cl_run: + return False + return True + +# Step decorator for async and sync functions, now using StepService +def step( + original_function: Optional[Callable] = None, + *, + name: Optional[str] = "", + type: Optional[str] = "undefined", + id: Optional[str] = None, + parent_id: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict] = None, + language: Optional[str] = None, + show_input: Union[bool, str] = "json", + default_open: bool = False + ) -> Callable: + def wrapper(func: Callable): + nonlocal name + if not name: + name = func.__name__ + if inspect.iscoroutinefunction(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with Step( + type=type, + name=name, + id=id, + parent_id=parent_id, + tags=tags, + language=language, + show_input=show_input, + default_open=default_open, + metadata=metadata, + ) as step: + try: + step.input = flatten_args_kwargs(func, args, kwargs) + except Exception: + pass + result = await func(*args, **kwargs) + try: + if result and not step.output: + step.output = result + except Exception: + step.is_error = True + step.output = str(result) + return result + return async_wrapper + else: + @wraps(func) + def sync_wrapper(*args, **kwargs): + with Step( + type=type, + name=name, + id=id, + parent_id=parent_id, + tags=tags, + language=language, + show_input=show_input, + default_open=default_open, + metadata=metadata, + ) as step: + try: + step.input = flatten_args_kwargs(func, args, kwargs) + except Exception: + pass + result = func(*args, **kwargs) + try: + if result and not step.output: + step.output = result + except Exception: + step.is_error = True + step.output = str(result) + return result + return sync_wrapper + func = original_function + if not func: + return wrapper + else: + return wrapper(func) diff --git a/backend/chainlit/models/thread.py b/backend/chainlit/models/thread.py new file mode 100644 index 0000000000..6b05c2b8d8 --- /dev/null +++ b/backend/chainlit/models/thread.py @@ -0,0 +1,95 @@ + +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Protocol, TypeVar, Union, Self +from sqlmodel import SQLModel, Field +from pydantic import PrivateAttr, BaseModel +import uuid +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel + +if TYPE_CHECKING: + from chainlit.element import ElementDict + from chainlit.step import StepDict + +# Unified thread model +class Thread(SQLModel, table=True): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + created_at: str = "" + name: Optional[str] = None + user_id: Optional[str] = None + user_identifier: Optional[str] = None + tags: Optional[List[str]] = Field(default_factory=list) + metadata: Optional[Dict] = None + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + # Private attributes for business logic (not persisted or serialized) + _steps: Optional[List] = None # You can specify List[Step] if imported + _elements: Optional[List] = None # You can specify List[Element] if imported + _runtime_state: dict = PrivateAttr(default_factory=dict) + + # Example business logic method + def add_tag(self, tag: str): + if tag not in self.tags: + self.tags.append(tag) + + def to_dict(self): + return self.model_dump(by_alias=True) + + @classmethod + def from_dict(cls, **kwargs) -> Self: + return cls.model_validate(**kwargs) + + +# Pagination and ThreadFilter +class Pagination(BaseModel): + first: int + cursor: Optional[str] = None + +class ThreadFilter(BaseModel): + feedback: Optional[int] = None + user_id: Optional[str] = None + search: Optional[str] = None + + +class PageInfo(BaseModel): + hasNextPage: bool + startCursor: Optional[str] + endCursor: Optional[str] + + def to_dict(self): + return self.model_dump() + + @classmethod + def from_dict(cls, page_info_dict: Dict) -> Self: + return cls(**page_info_dict) + +T = TypeVar("T", covariant=True) +class PaginatedResponse(BaseModel, Generic[T]): + page_info: PageInfo + data: List[T] + + def to_dict(self): + return self.model_dump() + + @classmethod + def from_dict( + cls, paginated_response_dict: Dict + ) -> "PaginatedResponse[T]": + page_info = PageInfo.from_dict(paginated_response_dict.get("page_info", {})) + data = [the_class.from_dict(d) for d in paginated_response_dict.get("data", [])] + return cls(page_info=page_info, data=data) + +# Thread requests/responses +class UpdateThreadRequest(BaseModel): + thread_id: str + name: str + +class DeleteThreadRequest(BaseModel): + thread_id: str + +class GetThreadsRequest(BaseModel): + pagination: Pagination + filter: ThreadFilter diff --git a/backend/chainlit/models/user.py b/backend/chainlit/models/user.py new file mode 100644 index 0000000000..96179fc667 --- /dev/null +++ b/backend/chainlit/models/user.py @@ -0,0 +1,35 @@ +from typing import Dict, Optional, Literal +from sqlmodel import SQLModel, Field +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel +Provider = Literal[ + "credentials", + "header", + "github", + "google", + "azure-ad", + "azure-ad-hybrid", + "okta", + "auth0", + "descope", +] + +# Non-persisted user (for runtime/session use) +class User(BaseModel): + identifier: str + display_name: Optional[str] = None + metadata: Dict = Field(default_factory=dict) + +# Persisted user (for database use) +class PersistedUser(SQLModel, table=True): + id: str = Field(primary_key=True) + identifier: str + display_name: Optional[str] = None + metadata: Dict = Field(default_factory=dict) + created_at: Optional[str] = None + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) \ No newline at end of file diff --git a/backend/chainlit/types/__init__.py b/backend/chainlit/types/__init__.py deleted file mode 100644 index 065afca0e1..0000000000 --- a/backend/chainlit/types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .step import Step, StepService \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 91a8db4a55..38f85abd70 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -56,7 +56,8 @@ dependencies = [ "pyjwt>=2.8.0,<3.0.0", "audioop-lts>=0.2.1,<0.3.0; python_version>='3.13'", "pydantic-settings>=2.10.1", - "sqlmodel>=0.0.24" + "sqlmodel>=0.0.24", + "alembic>=1.16.5" ] [project.urls] diff --git a/backend/uv.lock b/backend/uv.lock index 9ca675ba59..9beeab9cf7 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -170,6 +170,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, ] +[[package]] +name = "alembic" +version = "1.16.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/ca/4dc52902cf3491892d464f5265a81e9dff094692c8a049a3ed6a05fe7ee8/alembic-1.16.5.tar.gz", hash = "sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e", size = 1969868, upload-time = "2025-08-27T18:02:05.668Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/4a/4c61d4c84cfd9befb6fa08a702535b27b21fff08c946bc2f6139decbf7f7/alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3", size = 247355, upload-time = "2025-08-27T18:02:07.37Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -683,6 +698,7 @@ version = "2.7.2" source = { editable = "." } dependencies = [ { name = "aiofiles" }, + { name = "alembic" }, { name = "asyncer" }, { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, { name = "click" }, @@ -753,6 +769,7 @@ tests = [ requires-dist = [ { name = "aiofiles", specifier = ">=23.1.0,<25.0.0" }, { name = "aiosqlite", marker = "extra == 'tests'", specifier = ">=0.20.0,<1.0.0" }, + { name = "alembic", specifier = ">=1.16.5" }, { name = "asyncer", specifier = ">=0.0.8,<0.1.0" }, { name = "asyncpg", marker = "extra == 'custom-data'", specifier = ">=0.30.0,<1.0.0" }, { name = "audioop-lts", marker = "python_full_version >= '3.13'", specifier = ">=0.2.1,<0.3.0" }, @@ -2501,6 +2518,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/50/c5ccd2a50daa0a10c7f3f7d4e6992392454198cd8a7d99fcb96cb60d0686/llama_parse-0.6.54-py3-none-any.whl", hash = "sha256:c66c8d51cf6f29a44eaa8595a595de5d2598afc86e5a33a4cebe5fe228036920", size = 4879, upload-time = "2025-08-01T20:09:22.651Z" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" From 7820c1d07095d18a6b65e9fcaa249be7ecf47ebd Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Fri, 29 Aug 2025 14:32:52 -0500 Subject: [PATCH 3/9] remove duped alembic --- backend/alembic/README.md | 39 -------------- backend/alembic/env.py | 52 ------------------- .../alembic/versions/0001_create_tables.py | 20 ------- 3 files changed, 111 deletions(-) delete mode 100644 backend/alembic/README.md delete mode 100644 backend/alembic/env.py delete mode 100644 backend/alembic/versions/0001_create_tables.py diff --git a/backend/alembic/README.md b/backend/alembic/README.md deleted file mode 100644 index 5f504a1b51..0000000000 --- a/backend/alembic/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# Alembic Migrations for Chainlit SQLModelDataLayer - -This directory contains Alembic migration scripts for the SQLModel-based data layer. - -## Best Practices - -- **Do not use `SQLModel.metadata.create_all()` in production.** -- Always manage schema changes with Alembic migrations. -- Keep migration scripts in version control. -- Run migrations before starting the app, or enable auto-migration with `CHAINLIT_AUTO_MIGRATE=true`. - -## Usage - -1. **Configure your database URL** in `alembic.ini`: - ```ini - sqlalchemy.url = - ``` - -2. **Autogenerate a migration** (after changing models): - ```bash - alembic revision --autogenerate -m "Initial tables" - ``` - -3. **Apply migrations**: - ```bash - alembic upgrade head - ``` - -## Initial Migration - -The first migration should create all tables defined in `chainlit.models`. - -## env.py - -Alembic is configured to use `SQLModel.metadata` from `chainlit.models`. - ---- - -For more details, see the [Alembic documentation](https://alembic.sqlalchemy.org/en/latest/). diff --git a/backend/alembic/env.py b/backend/alembic/env.py deleted file mode 100644 index 0435faa2af..0000000000 --- a/backend/alembic/env.py +++ /dev/null @@ -1,52 +0,0 @@ -import sys -import os -from logging.config import fileConfig -from sqlalchemy import engine_from_config, pool -from alembic import context - -# Add the parent directory to sys.path to import models -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from chainlit.models import SQLModel - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -fileConfig(config.config_file_name) - -target_metadata = SQLModel.metadata - -def run_migrations_offline(): - """ - Run migrations in 'offline' mode. - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"} - ) - - with context.begin_transaction(): - context.run_migrations() - -def run_migrations_online(): - """ - Run migrations in 'online' mode. - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/backend/alembic/versions/0001_create_tables.py b/backend/alembic/versions/0001_create_tables.py deleted file mode 100644 index c04927f893..0000000000 --- a/backend/alembic/versions/0001_create_tables.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Initial migration: create all tables for SQLModelDataLayer -""" -from alembic import op -import sqlalchemy as sa - -# revision identifiers, used by Alembic. -revision = '0001_create_tables' -down_revision = None -branch_labels = None -depends_on = None - -def upgrade(): - # This will be auto-generated by Alembic's autogenerate feature, - # but for best practice, you can start with an empty upgrade and run autogenerate. - pass - -def downgrade(): - # Drop all tables (if needed) - pass From 0b34727fa8370af830f35472616c038bfd18e7d1 Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:11:36 -0500 Subject: [PATCH 4/9] metadata issue --- backend/chainlit/models/element.py | 227 ++++++++++++++++------------ backend/chainlit/models/feedback.py | 14 +- backend/chainlit/models/step.py | 58 ++++--- backend/chainlit/models/thread.py | 5 +- backend/chainlit/models/user.py | 7 +- 5 files changed, 183 insertions(+), 128 deletions(-) diff --git a/backend/chainlit/models/element.py b/backend/chainlit/models/element.py index 0d31ae271b..0718b61a56 100644 --- a/backend/chainlit/models/element.py +++ b/backend/chainlit/models/element.py @@ -1,7 +1,7 @@ -from typing import Optional, Dict, List, Literal, Union, ClassVar, TypeVar, Any, cast +from typing import Optional, Dict, List, Literal, Union, ClassVar, TypeVar, Any, cast, get_args from sqlmodel import SQLModel, Field import uuid -from pydantic import ConfigDict +from pydantic import ConfigDict, field_validator from pydantic.alias_generators import to_camel from syncer import asyncio import filetype @@ -18,16 +18,16 @@ } ElementType = Literal[ - "image", - "text", - "pdf", - "tasklist", - "audio", - "video", - "file", - "plotly", - "dataframe", - "custom", + "image", + "text", + "pdf", + "tasklist", + "audio", + "video", + "file", + "plotly", + "dataframe", + "custom", ] ElementDisplay = Literal["inline", "side", "page"] ElementSize = Literal["small", "medium", "large"] @@ -35,24 +35,45 @@ class Element(SQLModel, table=True): id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) thread_id: Optional[str] = None - type: ElementType + type: str = Field(..., nullable=False) name: str = "" url: Optional[str] = None path: Optional[str] = None object_key: Optional[str] = None chainlit_key: Optional[str] = None - display: ElementDisplay - size: Optional[ElementSize] = None + display: str = Field(..., nullable=False) + size: Optional[str] = None language: Optional[str] = None mime: Optional[str] = None for_id: Optional[str] = None # Add other common fields as needed - + model_config = ConfigDict( alias_generator=to_camel, populate_by_name=True, ) + @field_validator("type", mode="before") + def validate_type(cls, v): + allowed = list(get_args(ElementType)) + if v not in allowed: + raise ValueError(f"Invalid type: {v}. Must be one of: {allowed}") + return v + + @field_validator("display", mode="before") + def validate_display(cls, v): + allowed = list(get_args(ElementDisplay)) + if v not in allowed: + raise ValueError(f"Invalid display: {v}. Must be one of: {allowed}") + return v + + @field_validator("size", mode="before") + def validate_size(cls, v): + allowed = list(get_args(ElementSize)) + if v not in allowed: + raise ValueError(f"Invalid size: {v}. Must be one of: {allowed}") + return v + def to_dict(self): return self.model_dump(by_alias=True) @@ -79,7 +100,7 @@ def from_dict(cls, **kwargs): return Text(**kwargs) else: return File(**kwargs) - + @classmethod def infer_type_from_mime(cls, mime_type: str): """Infer the element type from a mime type. Useful to know which element to instantiate from a file upload.""" @@ -154,51 +175,55 @@ async def send(self, for_id: str, persist=True): # Subclasses for runtime logic (not persisted, but can be instantiated from Element) class Image(Element): - type: ClassVar[ElementType] = "image" - size: Optional[str] = "medium" + def __init__(self, *args, **kwargs): + kwargs.setdefault("type", "image") + kwargs.setdefault("size", "medium") + super().__init__(*args, **kwargs) class Text(Element): - type: ElementType = "text" - language: Optional[str] = None + def __init__(self, *args, **kwargs): + kwargs.setdefault("type", "text") + kwargs.setdefault("language", None) + super().__init__(*args, **kwargs) -class Pdf(Element): - type: ElementType = "pdf" - mime: str = "application/pdf" - page: Optional[int] = None +class Pdf(Element): + def __init__(self, *args, **kwargs): + kwargs.setdefault("type", "pdf") + kwargs.setdefault("mime", "application/pdf") + super().__init__(*args, **kwargs) + + page: Optional[int] = None class Pyplot(Element): - """Useful to send a pyplot to the UI.""" - - # We reuse the frontend image element to display the chart - type: ClassVar[ElementType] = "image" - - size: ElementSize = "medium" - # The type is set to Any because the figure is not serializable - # and its actual type is checked in __post_init__. - figure: Any = None - content: bytes = b"" - - def __post_init__(self) -> None: - from matplotlib.figure import Figure - from io import BytesIO + """Useful to send a pyplot to the UI.""" + def __init__(self, *args, figure=None, content=None, **kwargs): + kwargs.setdefault("type", "image") + kwargs.setdefault("size", "medium") + super().__init__(*args, **kwargs) + self.figure = figure + self.content = content if content is not None else b"" - if not isinstance(self.figure, Figure): - raise TypeError("figure must be a matplotlib.figure.Figure") - - image = BytesIO() - self.figure.savefig( - image, dpi=200, bbox_inches="tight", backend="Agg", format="png" - ) - self.content = image.getvalue() - - super().__post_init__() + def __post_init__(self) -> None: + if hasattr(self, "figure") and self.figure is not None: + from matplotlib.figure import Figure + from io import BytesIO + if not isinstance(self.figure, Figure): + raise TypeError("figure must be a matplotlib.figure.Figure") + image = BytesIO() + self.figure.savefig( + image, dpi=200, bbox_inches="tight", backend="Agg", format="png" + ) + self.content = image.getvalue() + super().__post_init__() class TaskList(Element): - type: ElementType = "tasklist" - tasks: List = Field(default_factory=list) - status: str = "Ready" - name: str = "tasklist" - content: str = "dummy content to pass validation" + def __init__(self, *args, tasks=None, status="Ready", name="tasklist", content="dummy content to pass validation", **kwargs): + kwargs.setdefault("type", "tasklist") + super().__init__(*args, **kwargs) + self.tasks = tasks if tasks is not None else [] + self.status = status + self.name = name + self.content = content def __post_init__(self) -> None: super().__post_init__() @@ -220,7 +245,6 @@ async def preprocess_content(self): {"title": task.title, "status": task.status.value, "forId": task.forId} for task in self.tasks ] - # store stringified json in content so that it's correctly stored in the database self.content = json.dumps( { @@ -232,65 +256,70 @@ async def preprocess_content(self): ) class Audio(Element): - type: ClassVar[ElementType] = "audio" - auto_play: bool = False + def __init__(self, *args, auto_play=False, **kwargs): + kwargs.setdefault("type", "audio") + super().__init__(*args, **kwargs) + self.auto_play = auto_play class Video(Element): - type: ClassVar[ElementType] = "video" - size: ElementSize = "medium" - # Override settings for each type of player in ReactPlayer - # https://github.com/cookpete/react-player?tab=readme-ov-file#config-prop - player_config: Optional[dict] = None + def __init__(self, *args, player_config=None, **kwargs): + kwargs.setdefault("type", "video") + kwargs.setdefault("size", "medium") + super().__init__(*args, **kwargs) + self.player_config = player_config class File(Element): - type: ElementType = "file" + def __init__(self, *args, **kwargs): + kwargs.setdefault("type", "file") + super().__init__(*args, **kwargs) class Plotly(Element): - type: ElementType = "plotly" - size: Optional[str] = "medium" - figure: Optional[Any] = None - content: str = "" + def __init__(self, *args, figure=None, content="", **kwargs): + kwargs.setdefault("type", "plotly") + kwargs.setdefault("size", "medium") + super().__init__(*args, **kwargs) + self.figure = figure + self.content = content def __post_init__(self) -> None: - from plotly import graph_objects as go, io as pio - - if not isinstance(self.figure, go.Figure): - raise TypeError("figure must be a plotly.graph_objects.Figure") - - self.figure.layout.autosize = True - self.figure.layout.width = None - self.figure.layout.height = None - self.content = pio.to_json(self.figure, validate=True) - self.mime = "application/json" - + if hasattr(self, "figure") and self.figure is not None: + from plotly import graph_objects as go, io as pio + if not isinstance(self.figure, go.Figure): + raise TypeError("figure must be a plotly.graph_objects.Figure") + self.figure.layout.autosize = True + self.figure.layout.width = None + self.figure.layout.height = None + self.content = pio.to_json(self.figure, validate=True) + self.mime = "application/json" super().__post_init__() class Dataframe(Element): - type: ElementType = "dataframe" - size: Optional[str] = "large" - data: Any = None # The type is Any because it is checked in __post_init__. + def __init__(self, *args, data=None, **kwargs): + kwargs.setdefault("type", "dataframe") + kwargs.setdefault("size", "large") + super().__init__(*args, **kwargs) + self.data = data def __post_init__(self) -> None: - """Ensures the data is a pandas DataFrame and converts it to JSON.""" - from pandas import DataFrame - - if not isinstance(self.data, DataFrame): - raise TypeError("data must be a pandas.DataFrame") - - self.content = self.data.to_json(orient="split", date_format="iso") + if hasattr(self, "data") and self.data is not None: + from pandas import DataFrame + if not isinstance(self.data, DataFrame): + raise TypeError("data must be a pandas.DataFrame") + self.content = self.data.to_json(orient="split", date_format="iso") super().__post_init__() class CustomElement(Element): - """Useful to send a custom element to the UI.""" - - type: ClassVar[ElementType] = "custom" - mime: str = "application/json" - props: Dict = Field(default_factory=dict) + """Useful to send a custom element to the UI.""" + def __init__(self, *args, mime="application/json", props=None, **kwargs): + kwargs.setdefault("type", "custom") + super().__init__(*args, **kwargs) + self.mime = mime + self.props = props if props is not None else {} - def __post_init__(self) -> None: - self.content = json.dumps(self.props) - super().__post_init__() - self.updatable = True + def __post_init__(self) -> None: + self.content = json.dumps(self.props) + super().__post_init__() + self.updatable = True - async def update(self): - await super().send(self.for_id) \ No newline at end of file + async def update(self): + await super().send(self.for_id) \ No newline at end of file diff --git a/backend/chainlit/models/feedback.py b/backend/chainlit/models/feedback.py index eee883084d..d56b78f4f0 100644 --- a/backend/chainlit/models/feedback.py +++ b/backend/chainlit/models/feedback.py @@ -1,7 +1,6 @@ -from typing import Dict, Optional, Literal +from typing import Dict, Optional, Literal, get_args from sqlmodel import SQLModel, Field -from pydantic import BaseModel -from pydantic import ConfigDict +from pydantic import BaseModel, field_validator, ConfigDict, conint from pydantic.alias_generators import to_camel FeedbackStrategy = Literal["BINARY"] @@ -9,7 +8,7 @@ class Feedback(SQLModel, table=True): id: Optional[str] = Field(default=None, primary_key=True) for_id: str - value: Literal[0, 1] + value: int = Field(..., ge=0, le=1) thread_id: Optional[str] = None comment: Optional[str] = None @@ -17,6 +16,13 @@ class Feedback(SQLModel, table=True): alias_generator=to_camel, populate_by_name=True, ) + + @field_validator("value", mode="before") + def validate_type(cls, v): + allowed = [0, 1] + if v not in allowed: + raise ValueError(f"Invalid value: {v}. Must be one of: {allowed}") + return v def to_dict(self): data = self.model_dump(by_alias=True) diff --git a/backend/chainlit/models/step.py b/backend/chainlit/models/step.py index e064c83ce7..e936e6bf66 100644 --- a/backend/chainlit/models/step.py +++ b/backend/chainlit/models/step.py @@ -5,10 +5,13 @@ import uuid from copy import deepcopy from functools import wraps -from typing import Callable, Dict, List, Optional, TypedDict, Union, Literal +from typing import Callable, Dict, List, Optional, TypedDict, Union, Literal, Any, get_args from sqlmodel import SQLModel, Field +from sqlalchemy import Column, JSON +from sqlalchemy.dialects.postgresql import JSONB from pydantic import PrivateAttr +from pydantic import field_validator # If you want to keep compatibility with literalai types, import as needed from literalai import BaseGeneration from pydantic import ConfigDict @@ -29,28 +32,34 @@ StepType = Union[TrueStepType, MessageStepType] - class Step(SQLModel, table=True): - id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) - name: str = "" - type: str = "undefined" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + name: str = Field(..., nullable=False) + type: str = Field(..., nullable=False) + thread_id: str = Field(..., foreign_key="thread.id", nullable=False) parent_id: Optional[str] = Field(default=None, foreign_key="step.id") - thread_id: Optional[str] = None - streaming: bool = False - persisted: bool = False - show_input: Union[bool, str] = "json" - is_error: Optional[bool] = False - metadata: Dict = Field(default_factory=dict) - tags: Optional[List[str]] = None - created_at: Optional[str] = None - start: Optional[str] = None - end: Optional[str] = None - generation: Optional[BaseGeneration] = None - language: Optional[str] = None - default_open: Optional[bool] = False - input: Optional[str] = "" - output: Optional[str] = "" - + disable_feedback: bool = Field(default=False, nullable=False) + streaming: bool = Field(default=False, nullable=False) + wait_for_answer: Optional[bool] = Field(default=None) + is_error: Optional[bool] = Field(default=None) + metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata') + input: Optional[str] = Field(default=None) + output: Optional[str] = Field(default=None) + created_at: Optional[str] = Field(default=None) + start: Optional[str] = Field(default=None) + end: Optional[str] = Field(default=None) + generation: Optional[dict] = Field(default_factory=dict, sa_column=Column('generation', JSON), alias='generation') + show_input: str = Field(default="json") + language: Optional[str] = Field(default=None) + indent: Optional[int] = Field(default=None) + tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + default_open: Optional[bool] = Field(default=False) + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + # TODO define relationship with Element # elements: List[Element] = Relationship(back_populates="step") # thread: Optional[Thread] = Relationship(back_populates="steps") @@ -66,6 +75,13 @@ class Step(SQLModel, table=True): _input: str = PrivateAttr(default="") _output: str = PrivateAttr(default="") + @field_validator("type", mode="before") + def validate_type(cls, v): + allowed = [v for arg in get_args(StepType) for v in (get_args(arg) if hasattr(arg, "__args__") else [arg])] + if v not in allowed: + raise ValueError(f"Invalid type: {v}. Must be one of: {allowed}") + return v + @property def input_value(self): return self._input diff --git a/backend/chainlit/models/thread.py b/backend/chainlit/models/thread.py index 6b05c2b8d8..74953abb4a 100644 --- a/backend/chainlit/models/thread.py +++ b/backend/chainlit/models/thread.py @@ -5,6 +5,7 @@ import uuid from pydantic import ConfigDict from pydantic.alias_generators import to_camel +from sqlalchemy import Column, JSON if TYPE_CHECKING: from chainlit.element import ElementDict @@ -17,8 +18,8 @@ class Thread(SQLModel, table=True): name: Optional[str] = None user_id: Optional[str] = None user_identifier: Optional[str] = None - tags: Optional[List[str]] = Field(default_factory=list) - metadata: Optional[Dict] = None + tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata') model_config = ConfigDict( alias_generator=to_camel, diff --git a/backend/chainlit/models/user.py b/backend/chainlit/models/user.py index 96179fc667..6d14d253b2 100644 --- a/backend/chainlit/models/user.py +++ b/backend/chainlit/models/user.py @@ -3,6 +3,9 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic.alias_generators import to_camel +from sqlalchemy import Column, JSON +import uuid + Provider = Literal[ "credentials", "header", @@ -23,10 +26,10 @@ class User(BaseModel): # Persisted user (for database use) class PersistedUser(SQLModel, table=True): - id: str = Field(primary_key=True) + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) identifier: str display_name: Optional[str] = None - metadata: Dict = Field(default_factory=dict) + metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata') created_at: Optional[str] = None model_config = ConfigDict( From 220fbb68676e90b8b0ea8b06ac4fea0d1d43e1e4 Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:25:29 -0500 Subject: [PATCH 5/9] more _metadata workaround to make it work --- backend/chainlit/models/step.py | 8 +++----- backend/chainlit/models/thread.py | 2 +- backend/chainlit/models/user.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/backend/chainlit/models/step.py b/backend/chainlit/models/step.py index e936e6bf66..0c2a48781b 100644 --- a/backend/chainlit/models/step.py +++ b/backend/chainlit/models/step.py @@ -12,7 +12,6 @@ from sqlalchemy.dialects.postgresql import JSONB from pydantic import PrivateAttr from pydantic import field_validator -# If you want to keep compatibility with literalai types, import as needed from literalai import BaseGeneration from pydantic import ConfigDict from pydantic.alias_generators import to_camel @@ -23,6 +22,7 @@ from chainlit.logger import logger from chainlit.types import FeedbackDict from chainlit.utils import utc_now +from chainlit.context import context TrueStepType = Literal[ "run", "tool", "llm", "embedding", "retrieval", "rerank", "undefined" @@ -42,7 +42,7 @@ class Step(SQLModel, table=True): streaming: bool = Field(default=False, nullable=False) wait_for_answer: Optional[bool] = Field(default=None) is_error: Optional[bool] = Field(default=None) - metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata') + metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) input: Optional[str] = Field(default=None) output: Optional[str] = Field(default=None) created_at: Optional[str] = Field(default=None) @@ -257,16 +257,14 @@ async def send(self): tasks = [el.send(for_id=self.id) for el in getattr(self, 'elements', [])] await asyncio.gather(*tasks) - from chainlit.context import check_add_step_in_cot, stub_step if not check_add_step_in_cot(self): - await context.emitter.send_step(stub_step(self)) + await context.emitter.send_step(self.to_dict()) else: await context.emitter.send_step(step_dict) return self async def stream_token(self, token: str, is_sequence=False, is_input=False): - from chainlit.context import context, check_add_step_in_cot, stub_step if not token: return diff --git a/backend/chainlit/models/thread.py b/backend/chainlit/models/thread.py index 74953abb4a..0a80d8943a 100644 --- a/backend/chainlit/models/thread.py +++ b/backend/chainlit/models/thread.py @@ -19,7 +19,7 @@ class Thread(SQLModel, table=True): user_id: Optional[str] = None user_identifier: Optional[str] = None tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) - metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata') + metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) model_config = ConfigDict( alias_generator=to_camel, diff --git a/backend/chainlit/models/user.py b/backend/chainlit/models/user.py index 6d14d253b2..49fff98c10 100644 --- a/backend/chainlit/models/user.py +++ b/backend/chainlit/models/user.py @@ -29,7 +29,7 @@ class PersistedUser(SQLModel, table=True): id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) identifier: str display_name: Optional[str] = None - metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata') + metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) created_at: Optional[str] = None model_config = ConfigDict( From 39c87c8a343011cff7bdfdef7b2f880db9f6482a Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Fri, 29 Aug 2025 17:00:07 -0500 Subject: [PATCH 6/9] updated elements --- backend/chainlit/models/__init__.py | 2 +- backend/chainlit/models/element.py | 123 ++++++++++++---------------- 2 files changed, 52 insertions(+), 73 deletions(-) diff --git a/backend/chainlit/models/__init__.py b/backend/chainlit/models/__init__.py index e2b017df11..d491140af3 100644 --- a/backend/chainlit/models/__init__.py +++ b/backend/chainlit/models/__init__.py @@ -2,4 +2,4 @@ from .user import User, PersistedUser from .thread import Thread from .feedback import Feedback, UpdateFeedbackRequest, DeleteFeedbackRequest -from .element import Element, Image, Text, Audio, Video, File, Pyplot, Plotly, CustomElement, Pdf, TaskList, Dataframe \ No newline at end of file +from .element import Element #, Image, Text, Audio, Video, File, Pyplot, Plotly, CustomElement, Pdf, TaskList, Dataframe \ No newline at end of file diff --git a/backend/chainlit/models/element.py b/backend/chainlit/models/element.py index 0718b61a56..d3a29ec2b6 100644 --- a/backend/chainlit/models/element.py +++ b/backend/chainlit/models/element.py @@ -10,6 +10,7 @@ from chainlit.logger import logger from chainlit.element import Task, TaskStatus import json +from sqlalchemy import Column, JSON mime_types = { "text": "text/plain", @@ -46,7 +47,10 @@ class Element(SQLModel, table=True): language: Optional[str] = None mime: Optional[str] = None for_id: Optional[str] = None - # Add other common fields as needed + page: Optional[int] = None + props: Optional[dict] = Field(default_factory=dict, sa_column=Column(JSON)) + auto_play: Optional[bool] = None + player_config: Optional[dict] = Field(default_factory=dict, sa_column=Column(JSON)) model_config = ConfigDict( alias_generator=to_camel, @@ -81,25 +85,25 @@ def to_dict(self): def from_dict(cls, **kwargs): type_ = kwargs.get("type", "file") if type_ == "image": - return Image(**kwargs) + return Image.model_validate(**kwargs) elif type_ == "audio": - return Audio(**kwargs) + return Audio.model_validate(**kwargs) elif type_ == "video": - return Video(**kwargs) + return Video.model_validate(**kwargs) elif type_ == "plotly": - return Plotly(**kwargs) + return Plotly.model_validate(**kwargs) elif type_ == "custom": - return CustomElement(**kwargs) + return CustomElement.model_validate(**kwargs) elif type_ == "pdf": - return Pdf(**kwargs) + return Pdf.model_validate(**kwargs) elif type_ == "tasklist": - return TaskList(**kwargs) + return TaskList.model_validate(**kwargs) elif type_ == "dataframe": - return Dataframe(**kwargs) + return Dataframe.model_validate(**kwargs) elif type_ == "text": - return Text(**kwargs) + return Text.model_validate(**kwargs) else: - return File(**kwargs) + return File.model_validate(**kwargs) @classmethod def infer_type_from_mime(cls, mime_type: str): @@ -175,34 +179,25 @@ async def send(self, for_id: str, persist=True): # Subclasses for runtime logic (not persisted, but can be instantiated from Element) class Image(Element): - def __init__(self, *args, **kwargs): - kwargs.setdefault("type", "image") - kwargs.setdefault("size", "medium") - super().__init__(*args, **kwargs) + type: str = "image" + size: str = "medium" class Text(Element): - def __init__(self, *args, **kwargs): - kwargs.setdefault("type", "text") - kwargs.setdefault("language", None) - super().__init__(*args, **kwargs) + type: str = "text" + language: Optional[str] = None -class Pdf(Element): - def __init__(self, *args, **kwargs): - kwargs.setdefault("type", "pdf") - kwargs.setdefault("mime", "application/pdf") - super().__init__(*args, **kwargs) - - page: Optional[int] = None +class Pdf(Element): + type: str = "pdf" + mime: str = "application/pdf" + page: Optional[int] = None class Pyplot(Element): """Useful to send a pyplot to the UI.""" - def __init__(self, *args, figure=None, content=None, **kwargs): - kwargs.setdefault("type", "image") - kwargs.setdefault("size", "medium") - super().__init__(*args, **kwargs) - self.figure = figure - self.content = content if content is not None else b"" - + type: str = "image" + size: str = "medium" + figure: Any = None + content: bytes = b"" + def __post_init__(self) -> None: if hasattr(self, "figure") and self.figure is not None: from matplotlib.figure import Figure @@ -217,14 +212,12 @@ def __post_init__(self) -> None: super().__post_init__() class TaskList(Element): - def __init__(self, *args, tasks=None, status="Ready", name="tasklist", content="dummy content to pass validation", **kwargs): - kwargs.setdefault("type", "tasklist") - super().__init__(*args, **kwargs) - self.tasks = tasks if tasks is not None else [] - self.status = status - self.name = name - self.content = content - + type: str = "tasklist" + tasks: list = [] + status: str = "Ready" + name: str = "tasklist" + content: str = "dummy content to pass validation" + def __post_init__(self) -> None: super().__post_init__() self.updatable = True @@ -256,31 +249,22 @@ async def preprocess_content(self): ) class Audio(Element): - def __init__(self, *args, auto_play=False, **kwargs): - kwargs.setdefault("type", "audio") - super().__init__(*args, **kwargs) - self.auto_play = auto_play + type: str = "audio" + auto_play: bool = False class Video(Element): - def __init__(self, *args, player_config=None, **kwargs): - kwargs.setdefault("type", "video") - kwargs.setdefault("size", "medium") - super().__init__(*args, **kwargs) - self.player_config = player_config + type: str = "video" + size: str = "medium" class File(Element): - def __init__(self, *args, **kwargs): - kwargs.setdefault("type", "file") - super().__init__(*args, **kwargs) + type: str = "file" class Plotly(Element): - def __init__(self, *args, figure=None, content="", **kwargs): - kwargs.setdefault("type", "plotly") - kwargs.setdefault("size", "medium") - super().__init__(*args, **kwargs) - self.figure = figure - self.content = content - + type: str = "plotly" + size: str = "medium" + figure: Any = None + content: str = "" + def __post_init__(self) -> None: if hasattr(self, "figure") and self.figure is not None: from plotly import graph_objects as go, io as pio @@ -294,12 +278,10 @@ def __post_init__(self) -> None: super().__post_init__() class Dataframe(Element): - def __init__(self, *args, data=None, **kwargs): - kwargs.setdefault("type", "dataframe") - kwargs.setdefault("size", "large") - super().__init__(*args, **kwargs) - self.data = data - + type: str = "dataframe" + size: str = "large" + data: Any = None + def __post_init__(self) -> None: if hasattr(self, "data") and self.data is not None: from pandas import DataFrame @@ -310,12 +292,9 @@ def __post_init__(self) -> None: class CustomElement(Element): """Useful to send a custom element to the UI.""" - def __init__(self, *args, mime="application/json", props=None, **kwargs): - kwargs.setdefault("type", "custom") - super().__init__(*args, **kwargs) - self.mime = mime - self.props = props if props is not None else {} - + type: str = "custom" + mime: str = "application/json" + def __post_init__(self) -> None: self.content = json.dumps(self.props) super().__post_init__() From 0379fca08edfbfab8185e6a84af58b822cfdf65b Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Wed, 3 Sep 2025 07:09:23 -0500 Subject: [PATCH 7/9] add __tablename__ --- backend/chainlit/cli/__init__.py | 2 +- .../data/{sql_model.py => sql_data_layer.py} | 152 +++++++++-------- backend/chainlit/models/__init__.py | 2 +- backend/chainlit/models/element.py | 156 ++++++++++-------- backend/chainlit/models/feedback.py | 8 +- backend/chainlit/models/step.py | 8 +- backend/chainlit/models/thread.py | 6 +- backend/chainlit/models/user.py | 5 +- 8 files changed, 195 insertions(+), 144 deletions(-) rename backend/chainlit/data/{sql_model.py => sql_data_layer.py} (77%) diff --git a/backend/chainlit/cli/__init__.py b/backend/chainlit/cli/__init__.py index ed3b122c2b..aab637a78d 100644 --- a/backend/chainlit/cli/__init__.py +++ b/backend/chainlit/cli/__init__.py @@ -90,7 +90,7 @@ def auto_run_alembic_upgrade(): if data_layer_func: try: dl_instance = data_layer_func() - from chainlit.data.sql_model import SQLModelDataLayer + from backend.chainlit.data.sql_data_layer import SQLModelDataLayer if isinstance(dl_instance, SQLModelDataLayer): # Get current version try: diff --git a/backend/chainlit/data/sql_model.py b/backend/chainlit/data/sql_data_layer.py similarity index 77% rename from backend/chainlit/data/sql_model.py rename to backend/chainlit/data/sql_data_layer.py index 978ad3c1df..4e00c56f73 100644 --- a/backend/chainlit/data/sql_model.py +++ b/backend/chainlit/data/sql_data_layer.py @@ -5,6 +5,8 @@ from chainlit.data.storage_clients.base import BaseStorageClient from chainlit.data.utils import queue_until_user_message from typing import Optional, Any, Dict +from datetime import datetime +from pydantic import ValidationError from chainlit.models import PersistedUser, User, Feedback, Thread, Element, Step import json import ssl @@ -16,8 +18,22 @@ ThreadFilter, PageInfo ) +from sqlalchemy.engine import make_url +from sqlalchemy.pool import NullPool +from sqlalchemy import event -class SQLModelDataLayer(BaseDataLayer): +ALLOWED_ASYNC_DRIVERS = { + "postgresql+asyncpg", + "postgresql+psycopg", # psycopg3 async + "sqlite+aiosqlite", + "mysql+aiomysql", + "mysql+asyncmy", + "mariadb+aiomysql", + "mariadb+asyncmy", + "mssql+aioodbc", +} + +class SQLDataLayer(BaseDataLayer): def __init__( self, conninfo: str, @@ -29,34 +45,38 @@ def __init__( ): self._conninfo = conninfo self.user_thread_limit = user_thread_limit - self.show_logger = show_logger - if connect_args is None: - connect_args = {} + self.show_logger = bool(show_logger) + + connect_args = dict(connect_args or {}) + + # Validate async driver and prepare per-dialect settings + url = make_url(self._conninfo) + driver = url.drivername # e.g., "postgresql+asyncpg" + backend = url.get_backend_name() # e.g., "postgresql" + if driver not in ALLOWED_ASYNC_DRIVERS: + raise ValueError(f"Connection URL must use an async driver. Got '{driver}'. Use one of: {ALLOWED_ASYNC_DRIVERS}") + if ssl_require: # Create an SSL context to require an SSL connection ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - connect_args["ssl"] = ssl_context - self.engine = AsyncEngine( - create_async_engine( + ssl_context.check_hostname = True + ssl_context.verify_mode = ssl.CERT_REQUIRED + connect_args.setdefault("ssl", ssl_context) + self.engine: AsyncEngine = create_async_engine( self._conninfo, connect_args=connect_args, echo=self.show_logger, ) - ) self.async_session = async_sessionmaker( bind=self.engine, expire_on_commit=False, class_=AsyncSession ) if storage_provider: self.storage_provider: Optional[BaseStorageClient] = storage_provider if self.show_logger: - logger.info("SQLModel storage client initialized") + logger.info("SQLDataLayer storage client initialized") else: self.storage_provider = None - logger.warning( - "SQLModel storage client is not initialized and elements will not be persisted!" - ) + logger.warning("SQLDataLayer storage client is not initialized and elements will not be persisted!") async def init_db(self): """ @@ -67,59 +87,62 @@ async def init_db(self): async with self.engine.begin() as conn: # await conn.run_sync(SQLModel.metadata.drop_all) # Uncomment to drop tables await conn.run_sync(SQLModel.metadata.create_all) + + async def aclose(self) -> None: + await self.engine.dispose() async def create_user(self, user: User) -> Optional[PersistedUser]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == user.identifier)) existing = result.scalar_one_or_none() - if not existing: - db_user = PersistedUser(identifier=user.identifier, metadata=user.metadata) - session.add(db_user) - await session.commit() - await session.refresh(db_user) - return PersistedUser(identifier=db_user.identifier, metadata=db_user.metadata) - return PersistedUser(identifier=existing.identifier, metadata=existing.metadata) + if existing: + return existing + db_user = PersistedUser( + identifier=user.identifier, + metadata=user.metadata, + ) + session.add(db_user) + return db_user async def get_user(self, identifier: str) -> Optional[PersistedUser]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier)) user = result.scalar_one_or_none() - if user: - return PersistedUser(identifier=user.identifier, metadata=user.metadata) - return None + return user async def update_user(self, identifier: str, metadata: Optional[dict] = None) -> Optional[PersistedUser]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier)) user = result.scalar_one_or_none() if user: if metadata is not None: user.metadata = metadata - await session.commit() await session.refresh(user) return PersistedUser(identifier=user.identifier, metadata=user.metadata) return None async def delete_user(self, identifier: str) -> bool: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier)) user = result.scalar_one_or_none() if user: await session.delete(user) - await session.commit() return True return False async def create_thread(self, thread_data: dict) -> Optional[Dict]: - async with self.async_session() as session: - db_thread = Thread.model_validate(thread_data) - session.add(db_thread) - await session.commit() - await session.refresh(db_thread) - return db_thread.to_dict() + try: + thread = Thread.model_validate(thread_data) + except ValidationError as e: + logger.error(f"Thread data validation error: {e}") + return None + async with self.async_session.begin() as session: + session.add(thread) + await session.refresh(thread) + return thread.to_dict() async def get_thread(self, thread_id: str) -> Optional[Dict]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Thread).where(Thread.id == thread_id)) thread = result.scalar_one_or_none() if thread: @@ -127,7 +150,7 @@ async def get_thread(self, thread_id: str) -> Optional[Dict]: return None async def get_thread_author(self, thread_id: str) -> str: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Thread).where(Thread.id == thread_id)) thread: Thread = result.scalar_one_or_none() if thread and thread.user_identifier: @@ -135,38 +158,39 @@ async def get_thread_author(self, thread_id: str) -> str: return "" async def update_thread(self, thread_id: str, **kwargs) -> Optional[Dict]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Thread).where(Thread.id == thread_id)) thread = result.scalar_one_or_none() if thread: for k, v in kwargs.items(): setattr(thread, k, v) - await session.commit() await session.refresh(thread) return thread.to_dict() return None async def delete_thread(self, thread_id: str) -> bool: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Thread).where(Thread.id == thread_id)) thread = result.scalar_one_or_none() if thread: await session.delete(thread) - await session.commit() return True return False @queue_until_user_message() async def create_step(self, step_data: dict) -> Optional[Dict]: - async with self.async_session() as session: - db_step = Step.model_validate(step_data) - session.add(db_step) - await session.commit() - await session.refresh(db_step) - return db_step.to_dict() + try: + step = Step.model_validate(step_data) + except ValidationError as e: + logger.error(f"Thread data validation error: {e}") + return None + async with self.async_session.begin() as session: + session.add(step) + await session.refresh(step) + return step.to_dict() async def get_step(self, step_id: str) -> Optional[Dict]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Step).where(Step.id == step_id)) step = result.scalar_one_or_none() if step: @@ -175,25 +199,23 @@ async def get_step(self, step_id: str) -> Optional[Dict]: @queue_until_user_message() async def update_step(self, step_id: str, **kwargs) -> Optional[Dict]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Step).where(Step.id == step_id)) step = result.scalar_one_or_none() if step: for k, v in kwargs.items(): setattr(step, k, v) - await session.commit() await session.refresh(step) return step.to_dict() return None @queue_until_user_message() async def delete_step(self, step_id: str) -> bool: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Step).where(Step.id == step_id)) step = result.scalar_one_or_none() if step: await session.delete(step) - await session.commit() return True return False @@ -201,7 +223,7 @@ async def upsert_feedback(self, feedback: Feedback) -> str: feedback_id = feedback.id or str(uuid.uuid4()) feedback_dict = feedback.dict() feedback_dict["id"] = feedback_id - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Feedback).where(Feedback.id == feedback_id)) db_feedback = result.scalar_one_or_none() if db_feedback: @@ -210,12 +232,11 @@ async def upsert_feedback(self, feedback: Feedback) -> str: else: db_feedback = Feedback.model_validate(feedback_dict) session.add(db_feedback) - await session.commit() await session.refresh(db_feedback) return db_feedback.id async def get_feedback(self, feedback_id: str) -> Optional[Dict]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Feedback).where(Feedback.id == feedback_id)) feedback = result.scalar_one_or_none() if feedback: @@ -223,17 +244,16 @@ async def get_feedback(self, feedback_id: str) -> Optional[Dict]: return None async def delete_feedback(self, feedback_id: str) -> bool: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute(select(Feedback).where(Feedback.id == feedback_id)) feedback = result.scalar_one_or_none() if feedback: await session.delete(feedback) - await session.commit() return True return False async def get_element(self, thread_id: str, element_id: str) -> Optional[Dict]: - async with self.async_session() as session: + async with self.async_session.begin() as session: result = await session.execute( select(Element).where(Element.thread_id == thread_id, Element.id == element_id) ) @@ -252,10 +272,10 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional[Dict]: @queue_until_user_message() async def create_element(self, element: "Element"): if self.show_logger: - logger.info(f"SQLModel: create_element, element_id = {element.id}") + logger.info(f"SQLDataLayer: create_element, element_id = {element.id}") if not self.storage_provider: - logger.warning("SQLModel: create_element error. No blob_storage_client is configured!") + logger.warning("SQLDataLayer: create_element error. No blob_storage_client is configured!") return if not element.for_id: return @@ -299,19 +319,18 @@ async def create_element(self, element: "Element"): if "props" in element_dict_cleaned: element_dict_cleaned["props"] = json.dumps(element_dict_cleaned["props"]) - async with self.async_session() as session: + async with self.async_session.begin() as session: db_element = Element.model_validate(element_dict_cleaned) session.add(db_element) - await session.commit() await session.refresh(db_element) return db_element.to_dict() @queue_until_user_message() async def delete_element(self, element_id: str, thread_id: Optional[str] = None): if self.show_logger: - logger.info(f"SQLModel: delete_element, element_id={element_id}") + logger.info(f"SQLDataLayer: delete_element, element_id={element_id}") - async with self.async_session() as session: + async with self.async_session.begin() as session: query = select(Element).where(Element.id == element_id) if thread_id: query = query.where(Element.thread_id == thread_id) @@ -326,7 +345,6 @@ async def delete_element(self, element_id: str, thread_id: Optional[str] = None) await self.storage_provider.delete_file(object_key=element['objectKey']) if element: await session.delete(element) - await session.commit() async def build_debug_url(self) -> str: # Implement as needed, or return empty string for now @@ -336,7 +354,7 @@ async def list_threads( self, pagination: Pagination, filters: ThreadFilter ) -> PaginatedResponse[Dict]: # Fetch threads for a user, apply pagination and filters - async with self.async_session() as session: + async with self.async_session.begin() as session: if filters.userId: query = select(Thread).where(Thread.user_id == filters.userId) result = await session.execute(query) diff --git a/backend/chainlit/models/__init__.py b/backend/chainlit/models/__init__.py index d491140af3..e2b017df11 100644 --- a/backend/chainlit/models/__init__.py +++ b/backend/chainlit/models/__init__.py @@ -2,4 +2,4 @@ from .user import User, PersistedUser from .thread import Thread from .feedback import Feedback, UpdateFeedbackRequest, DeleteFeedbackRequest -from .element import Element #, Image, Text, Audio, Video, File, Pyplot, Plotly, CustomElement, Pdf, TaskList, Dataframe \ No newline at end of file +from .element import Element, Image, Text, Audio, Video, File, Pyplot, Plotly, CustomElement, Pdf, TaskList, Dataframe \ No newline at end of file diff --git a/backend/chainlit/models/element.py b/backend/chainlit/models/element.py index d3a29ec2b6..91f82df8b8 100644 --- a/backend/chainlit/models/element.py +++ b/backend/chainlit/models/element.py @@ -2,20 +2,23 @@ from sqlmodel import SQLModel, Field import uuid from pydantic import ConfigDict, field_validator +from pydantic import PrivateAttr from pydantic.alias_generators import to_camel -from syncer import asyncio +import asyncio import filetype from chainlit.context import context from chainlit.data import get_data_layer from chainlit.logger import logger from chainlit.element import Task, TaskStatus import json -from sqlalchemy import Column, JSON +from sqlalchemy import Column, JSON, ForeignKey, String + +APPLICATION_JSON = "application/json" mime_types = { - "text": "text/plain", - "tasklist": "application/json", - "plotly": "application/json", + "text": "text/plain", + "tasklist": APPLICATION_JSON, + "plotly": APPLICATION_JSON, } ElementType = Literal[ @@ -34,8 +37,10 @@ ElementSize = Literal["small", "medium", "large"] class Element(SQLModel, table=True): + __tablename__ = "elements" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) - thread_id: Optional[str] = None + thread_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=True)) type: str = Field(..., nullable=False) name: str = "" url: Optional[str] = None @@ -46,16 +51,26 @@ class Element(SQLModel, table=True): size: Optional[str] = None language: Optional[str] = None mime: Optional[str] = None - for_id: Optional[str] = None + for_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True)) page: Optional[int] = None props: Optional[dict] = Field(default_factory=dict, sa_column=Column(JSON)) auto_play: Optional[bool] = None player_config: Optional[dict] = Field(default_factory=dict, sa_column=Column(JSON)) + # Non-DB payload used by runtime logic (private attribute) + _content: Optional[Union[str, bytes]] = PrivateAttr(default=None) model_config = ConfigDict( alias_generator=to_camel, populate_by_name=True, ) + + @property + def content(self) -> Optional[Union[str, bytes]]: + return self._content + + @content.setter + def content(self, value: Optional[Union[str, bytes]]): + self._content = value @field_validator("type", mode="before") def validate_type(cls, v): @@ -73,6 +88,8 @@ def validate_display(cls, v): @field_validator("size", mode="before") def validate_size(cls, v): + if v is None: + return v allowed = list(get_args(ElementSize)) if v not in allowed: raise ValueError(f"Invalid size: {v}. Must be one of: {allowed}") @@ -83,27 +100,10 @@ def to_dict(self): @classmethod def from_dict(cls, **kwargs): + # Default to file if missing type_ = kwargs.get("type", "file") - if type_ == "image": - return Image.model_validate(**kwargs) - elif type_ == "audio": - return Audio.model_validate(**kwargs) - elif type_ == "video": - return Video.model_validate(**kwargs) - elif type_ == "plotly": - return Plotly.model_validate(**kwargs) - elif type_ == "custom": - return CustomElement.model_validate(**kwargs) - elif type_ == "pdf": - return Pdf.model_validate(**kwargs) - elif type_ == "tasklist": - return TaskList.model_validate(**kwargs) - elif type_ == "dataframe": - return Dataframe.model_validate(**kwargs) - elif type_ == "text": - return Text.model_validate(**kwargs) - else: - return File.model_validate(**kwargs) + model = TYPE_MAP.get(type_, File) + return model.model_validate(kwargs) @classmethod def infer_type_from_mime(cls, mime_type: str): @@ -123,15 +123,15 @@ def infer_type_from_mime(cls, mime_type: str): else: return "file" - async def _create(self, persist=True) -> bool: - if getattr(self, "persisted", False) and not getattr(self, "updatable", False): - return True + async def _create(self, persist=True) -> None: + was_persisted = bool(getattr(self, "persisted", False)) + if was_persisted and not getattr(self, "updatable", False): + return None data_layer = get_data_layer() if data_layer and persist: try: - import asyncio - task = asyncio.create_task(data_layer.create_element(self)) + self._bg_task = asyncio.create_task(data_layer.create_element(self)) except Exception as e: logger.error(f"Failed to create element: {e!s}") @@ -145,7 +145,7 @@ async def _create(self, persist=True) -> bool: self.chainlit_key = file_dict["id"] self.persisted = True - return True + return None async def remove(self): data_layer = get_data_layer() @@ -178,27 +178,30 @@ async def send(self, for_id: str, persist=True): ElementBased = TypeVar("ElementBased", bound=Element) # Subclasses for runtime logic (not persisted, but can be instantiated from Element) -class Image(Element): +class Image(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "image" size: str = "medium" -class Text(Element): +class Text(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "text" language: Optional[str] = None -class Pdf(Element): +class Pdf(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "pdf" mime: str = "application/pdf" page: Optional[int] = None -class Pyplot(Element): +class Pyplot(Element, table=False): + __tablename__: ClassVar[None] = None """Useful to send a pyplot to the UI.""" type: str = "image" size: str = "medium" - figure: Any = None - content: bytes = b"" + figure: Any = Field(default=None, exclude=True) - def __post_init__(self) -> None: + def model_post_init(self, __context) -> None: if hasattr(self, "figure") and self.figure is not None: from matplotlib.figure import Figure from io import BytesIO @@ -209,28 +212,28 @@ def __post_init__(self) -> None: image, dpi=200, bbox_inches="tight", backend="Agg", format="png" ) self.content = image.getvalue() - super().__post_init__() + super().model_post_init(__context) -class TaskList(Element): +class TaskList(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "tasklist" - tasks: list = [] + tasks: List[Task] = Field(default_factory=list, exclude=True) status: str = "Ready" name: str = "tasklist" - content: str = "dummy content to pass validation" - def __post_init__(self) -> None: - super().__post_init__() + def model_post_init(self, __context) -> None: + super().model_post_init(__context) self.updatable = True async def add_task(self, task: Task): self.tasks.append(task) async def update(self): - await self.send() + await self.send(for_id=self.for_id or "") - async def send(self): + async def send(self, for_id: str, persist: bool = True): await self.preprocess_content() - await super().send(for_id="") + await super().send(for_id=for_id, persist=persist) async def preprocess_content(self): # serialize enum @@ -248,24 +251,27 @@ async def preprocess_content(self): ensure_ascii=False, ) -class Audio(Element): +class Audio(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "audio" auto_play: bool = False -class Video(Element): +class Video(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "video" size: str = "medium" -class File(Element): +class File(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "file" -class Plotly(Element): +class Plotly(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "plotly" size: str = "medium" - figure: Any = None - content: str = "" + figure: Any = Field(default=None, exclude=True) - def __post_init__(self) -> None: + def model_post_init(self, __context) -> None: if hasattr(self, "figure") and self.figure is not None: from plotly import graph_objects as go, io as pio if not isinstance(self.figure, go.Figure): @@ -274,31 +280,47 @@ def __post_init__(self) -> None: self.figure.layout.width = None self.figure.layout.height = None self.content = pio.to_json(self.figure, validate=True) - self.mime = "application/json" - super().__post_init__() + self.mime = APPLICATION_JSON + super().model_post_init(__context) -class Dataframe(Element): +class Dataframe(Element, table=False): + __tablename__: ClassVar[None] = None type: str = "dataframe" size: str = "large" - data: Any = None + data: Any = Field(default=None, exclude=True) - def __post_init__(self) -> None: + def model_post_init(self, __context) -> None: if hasattr(self, "data") and self.data is not None: from pandas import DataFrame if not isinstance(self.data, DataFrame): raise TypeError("data must be a pandas.DataFrame") self.content = self.data.to_json(orient="split", date_format="iso") - super().__post_init__() + super().model_post_init(__context) -class CustomElement(Element): +class CustomElement(Element, table=False): + __tablename__: ClassVar[None] = None """Useful to send a custom element to the UI.""" type: str = "custom" - mime: str = "application/json" + mime: str = APPLICATION_JSON - def __post_init__(self) -> None: + def model_post_init(self, __context) -> None: self.content = json.dumps(self.props) - super().__post_init__() + super().model_post_init(__context) self.updatable = True async def update(self): - await super().send(self.for_id) \ No newline at end of file + await super().send(self.for_id) + +# Simple mapping for type discrimination (Pyplot shares "image", so not included) +TYPE_MAP: Dict[str, Any] = { + "image": Image, + "text": Text, + "pdf": Pdf, + "tasklist": TaskList, + "audio": Audio, + "video": Video, + "file": File, + "plotly": Plotly, + "dataframe": Dataframe, + "custom": CustomElement, +} \ No newline at end of file diff --git a/backend/chainlit/models/feedback.py b/backend/chainlit/models/feedback.py index d56b78f4f0..13290f1d75 100644 --- a/backend/chainlit/models/feedback.py +++ b/backend/chainlit/models/feedback.py @@ -2,12 +2,16 @@ from sqlmodel import SQLModel, Field from pydantic import BaseModel, field_validator, ConfigDict, conint from pydantic.alias_generators import to_camel +import uuid +from sqlalchemy import Column, ForeignKey, String FeedbackStrategy = Literal["BINARY"] class Feedback(SQLModel, table=True): - id: Optional[str] = Field(default=None, primary_key=True) - for_id: str + __tablename__ = "feedbacks" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + for_id: str = Field(sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"))) value: int = Field(..., ge=0, le=1) thread_id: Optional[str] = None comment: Optional[str] = None diff --git a/backend/chainlit/models/step.py b/backend/chainlit/models/step.py index 0c2a48781b..ee0079f877 100644 --- a/backend/chainlit/models/step.py +++ b/backend/chainlit/models/step.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, Optional, TypedDict, Union, Literal, Any, get_args from sqlmodel import SQLModel, Field -from sqlalchemy import Column, JSON +from sqlalchemy import Column, JSON, ForeignKey, String from sqlalchemy.dialects.postgresql import JSONB from pydantic import PrivateAttr from pydantic import field_validator @@ -33,11 +33,13 @@ StepType = Union[TrueStepType, MessageStepType] class Step(SQLModel, table=True): + __tablename__ = "steps" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) name: str = Field(..., nullable=False) type: str = Field(..., nullable=False) - thread_id: str = Field(..., foreign_key="thread.id", nullable=False) - parent_id: Optional[str] = Field(default=None, foreign_key="step.id") + thread_id: str = Field(..., sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=False)) + parent_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True)) disable_feedback: bool = Field(default=False, nullable=False) streaming: bool = Field(default=False, nullable=False) wait_for_answer: Optional[bool] = Field(default=None) diff --git a/backend/chainlit/models/thread.py b/backend/chainlit/models/thread.py index 0a80d8943a..04741fe39e 100644 --- a/backend/chainlit/models/thread.py +++ b/backend/chainlit/models/thread.py @@ -5,7 +5,7 @@ import uuid from pydantic import ConfigDict from pydantic.alias_generators import to_camel -from sqlalchemy import Column, JSON +from sqlalchemy import Column, JSON, ForeignKey, String if TYPE_CHECKING: from chainlit.element import ElementDict @@ -13,10 +13,12 @@ # Unified thread model class Thread(SQLModel, table=True): + __tablename__ = "threads" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) created_at: str = "" name: Optional[str] = None - user_id: Optional[str] = None + user_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=True)) user_identifier: Optional[str] = None tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) diff --git a/backend/chainlit/models/user.py b/backend/chainlit/models/user.py index 49fff98c10..95d4ca3085 100644 --- a/backend/chainlit/models/user.py +++ b/backend/chainlit/models/user.py @@ -5,6 +5,7 @@ from pydantic.alias_generators import to_camel from sqlalchemy import Column, JSON import uuid +from chainlit.utils import utc_now Provider = Literal[ "credentials", @@ -26,11 +27,13 @@ class User(BaseModel): # Persisted user (for database use) class PersistedUser(SQLModel, table=True): + __tablename__ = "users" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) identifier: str display_name: Optional[str] = None metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) - created_at: Optional[str] = None + created_at: str = Field(default_factory=utc_now(), primary_key=True) model_config = ConfigDict( alias_generator=to_camel, From 254608d8de5553ac879f914b29226e724d52b35e Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Wed, 3 Sep 2025 07:33:45 -0500 Subject: [PATCH 8/9] new subclassing approach --- backend/chainlit/models/element.py | 304 +++++++++++++++++------------ 1 file changed, 184 insertions(+), 120 deletions(-) diff --git a/backend/chainlit/models/element.py b/backend/chainlit/models/element.py index 91f82df8b8..b992204bf5 100644 --- a/backend/chainlit/models/element.py +++ b/backend/chainlit/models/element.py @@ -1,7 +1,8 @@ -from typing import Optional, Dict, List, Literal, Union, ClassVar, TypeVar, Any, cast, get_args +from typing import Optional, Dict, List, Union, ClassVar, TypeVar, Any, Literal, get_args from sqlmodel import SQLModel, Field import uuid -from pydantic import ConfigDict, field_validator +from pydantic import ConfigDict +from pydantic.functional_validators import field_validator from pydantic import PrivateAttr from pydantic.alias_generators import to_camel import asyncio @@ -15,12 +16,6 @@ APPLICATION_JSON = "application/json" -mime_types = { - "text": "text/plain", - "tasklist": APPLICATION_JSON, - "plotly": APPLICATION_JSON, -} - ElementType = Literal[ "image", "text", @@ -33,31 +28,35 @@ "dataframe", "custom", ] + ElementDisplay = Literal["inline", "side", "page"] ElementSize = Literal["small", "medium", "large"] -class Element(SQLModel, table=True): - __tablename__ = "elements" - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) - thread_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=True)) - type: str = Field(..., nullable=False) +mime_types: Dict[str, str] = { + "text": "text/plain", + "tasklist": APPLICATION_JSON, + "plotly": APPLICATION_JSON, +} +class ElementBase(SQLModel): + type: ElementType name: str = "" url: Optional[str] = None path: Optional[str] = None object_key: Optional[str] = None chainlit_key: Optional[str] = None - display: str = Field(..., nullable=False) - size: Optional[str] = None + display: ElementDisplay = "inline" + size: Optional[ElementSize] = None language: Optional[str] = None mime: Optional[str] = None - for_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True)) page: Optional[int] = None - props: Optional[dict] = Field(default_factory=dict, sa_column=Column(JSON)) + props: Optional[dict] = None auto_play: Optional[bool] = None - player_config: Optional[dict] = Field(default_factory=dict, sa_column=Column(JSON)) - # Non-DB payload used by runtime logic (private attribute) + player_config: Optional[dict] = None + # runtime-only _content: Optional[Union[str, bytes]] = PrivateAttr(default=None) + _persisted: bool = PrivateAttr(default=False) + _updatable: bool = PrivateAttr(default=False) + _bg_task: Any = PrivateAttr(default=None) model_config = ConfigDict( alias_generator=to_camel, @@ -71,29 +70,6 @@ def content(self) -> Optional[Union[str, bytes]]: @content.setter def content(self, value: Optional[Union[str, bytes]]): self._content = value - - @field_validator("type", mode="before") - def validate_type(cls, v): - allowed = list(get_args(ElementType)) - if v not in allowed: - raise ValueError(f"Invalid type: {v}. Must be one of: {allowed}") - return v - - @field_validator("display", mode="before") - def validate_display(cls, v): - allowed = list(get_args(ElementDisplay)) - if v not in allowed: - raise ValueError(f"Invalid display: {v}. Must be one of: {allowed}") - return v - - @field_validator("size", mode="before") - def validate_size(cls, v): - if v is None: - return v - allowed = list(get_args(ElementSize)) - if v not in allowed: - raise ValueError(f"Invalid size: {v}. Must be one of: {allowed}") - return v def to_dict(self): return self.model_dump(by_alias=True) @@ -101,8 +77,11 @@ def to_dict(self): @classmethod def from_dict(cls, **kwargs): # Default to file if missing - type_ = kwargs.get("type", "file") - model = TYPE_MAP.get(type_, File) + t = kwargs.get("type", "file") + if t not in TYPE_MAP: + t = "file" + kwargs["type"] = t + model = TYPE_MAP.get(t, File) return model.model_validate(kwargs) @classmethod @@ -123,28 +102,51 @@ def infer_type_from_mime(cls, mime_type: str): else: return "file" - async def _create(self, persist=True) -> None: - was_persisted = bool(getattr(self, "persisted", False)) - if was_persisted and not getattr(self, "updatable", False): - return None + def _resolve_mime(self) -> None: + # Resolve MIME if needed + if self.mime: + return + key = self.type + if isinstance(key, str) and key in mime_types: + self.mime = mime_types[key] + elif self.path or isinstance(self.content, (bytes, bytearray)): + file_type = filetype.guess(self.path or self.content) + if file_type: + self.mime = file_type.mime + elif self.url: + import mimetypes + self.mime = mimetypes.guess_type(self.url)[0] + + async def _persist_file_if_needed(self) -> None: + # Persist file if needed + if self.url: + return + if not self.chainlit_key or getattr(self, "updatable", False) or self._updatable: + file_dict = await context.session.persist_file( + name=self.name, + path=self.path, + content=self.content, + mime=self.mime or "", + ) + self.chainlit_key = file_dict["id"] + + async def _create(self, persist: bool = True, for_id: Optional[str] = None) -> None: + if self._persisted and not (getattr(self, "updatable", False) or self._updatable): + return None + + self._resolve_mime() + await self._persist_file_if_needed() data_layer = get_data_layer() if data_layer and persist: try: - self._bg_task = asyncio.create_task(data_layer.create_element(self)) + # Map to DB element and persist + db_elem = Element.from_base(self, for_id=for_id) + self._bg_task = asyncio.create_task(data_layer.create_element(db_elem)) except Exception as e: logger.error(f"Failed to create element: {e!s}") - if not self.url and (not self.chainlit_key or getattr(self, "updatable", False)): - file_dict = await context.session.persist_file( - name=self.name, - path=self.path, - content=self.content, - mime=self.mime or "", - ) - self.chainlit_key = file_dict["id"] - - self.persisted = True + self._persisted = True return None async def remove(self): @@ -153,52 +155,34 @@ async def remove(self): await data_layer.delete_element(self.id, self.thread_id) await context.emitter.emit("remove_element", {"id": self.id}) - async def send(self, for_id: str, persist=True): - self.for_id = for_id - - if not self.mime: - if hasattr(self, "type") and self.type in mime_types: - self.mime = mime_types[self.type] - elif self.path or isinstance(self.content, (bytes, bytearray)): - import filetype - file_type = filetype.guess(self.path or self.content) - if file_type: - self.mime = file_type.mime - elif self.url: - import mimetypes - self.mime = mimetypes.guess_type(self.url)[0] - - await self._create(persist=persist) + async def send(self, for_id: str, persist: bool = True): + await self._create(persist=persist, for_id=for_id) if not self.url and not self.chainlit_key: raise ValueError("Must provide url or chainlit key to send element") await context.emitter.send_element(self.to_dict()) -ElementBased = TypeVar("ElementBased", bound=Element) +ElementBased = TypeVar("ElementBased", bound=ElementBase) -# Subclasses for runtime logic (not persisted, but can be instantiated from Element) -class Image(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "image" - size: str = "medium" +# Subclasses for runtime logic (not DB tables) +class Image(ElementBase): + type: Literal["image"] = "image" + size: Optional[ElementSize] = "medium" -class Text(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "text" +class Text(ElementBase): + type: Literal["text"] = "text" language: Optional[str] = None -class Pdf(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "pdf" +class Pdf(ElementBase): + type: Literal["pdf"] = "pdf" mime: str = "application/pdf" page: Optional[int] = None -class Pyplot(Element, table=False): - __tablename__: ClassVar[None] = None +class Pyplot(ElementBase): """Useful to send a pyplot to the UI.""" - type: str = "image" - size: str = "medium" + type: Literal["image"] = "image" + size: Optional[ElementSize] = "medium" figure: Any = Field(default=None, exclude=True) def model_post_init(self, __context) -> None: @@ -214,16 +198,16 @@ def model_post_init(self, __context) -> None: self.content = image.getvalue() super().model_post_init(__context) -class TaskList(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "tasklist" +class TaskList(ElementBase): + type: Literal["tasklist"] = "tasklist" tasks: List[Task] = Field(default_factory=list, exclude=True) status: str = "Ready" name: str = "tasklist" def model_post_init(self, __context) -> None: super().model_post_init(__context) - self.updatable = True + self._updatable = True + setattr(self, "updatable", True) async def add_task(self, task: Task): self.tasks.append(task) @@ -251,24 +235,20 @@ async def preprocess_content(self): ensure_ascii=False, ) -class Audio(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "audio" +class Audio(ElementBase): + type: Literal["audio"] = "audio" auto_play: bool = False -class Video(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "video" - size: str = "medium" +class Video(ElementBase): + type: Literal["video"] = "video" + size: Optional[ElementSize] = "medium" -class File(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "file" - -class Plotly(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "plotly" - size: str = "medium" +class File(ElementBase): + type: Literal["file"] = "file" + +class Plotly(ElementBase): + type: Literal["plotly"] = "plotly" + size: Optional[ElementSize] = "medium" figure: Any = Field(default=None, exclude=True) def model_post_init(self, __context) -> None: @@ -283,10 +263,9 @@ def model_post_init(self, __context) -> None: self.mime = APPLICATION_JSON super().model_post_init(__context) -class Dataframe(Element, table=False): - __tablename__: ClassVar[None] = None - type: str = "dataframe" - size: str = "large" +class Dataframe(ElementBase): + type: Literal["dataframe"] = "dataframe" + size: Optional[ElementSize] = "large" data: Any = Field(default=None, exclude=True) def model_post_init(self, __context) -> None: @@ -297,19 +276,104 @@ def model_post_init(self, __context) -> None: self.content = self.data.to_json(orient="split", date_format="iso") super().model_post_init(__context) -class CustomElement(Element, table=False): - __tablename__: ClassVar[None] = None +class CustomElement(ElementBase): """Useful to send a custom element to the UI.""" - type: str = "custom" + type: Literal["custom"] = "custom" mime: str = APPLICATION_JSON def model_post_init(self, __context) -> None: self.content = json.dumps(self.props) super().model_post_init(__context) - self.updatable = True + self._updatable = True + setattr(self, "updatable", True) async def update(self): - await super().send(self.for_id) + await super().send(for_id="") + +# DB model with table=True +class Element(ElementBase, table=True): + __tablename__ = "elements" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + thread_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=True), + ) + for_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True), + ) + # Override Literal fields with DB-mappable types + type: str = Field(..., nullable=False) + display: str = Field(..., nullable=False) + size: Optional[str] = None + props: Optional[dict] = Field(default_factory=dict, sa_type=JSON) + player_config: Optional[dict] = Field(default_factory=dict, sa_type=JSON) + + # Strict validation of DB fields using runtime Literal definitions + @field_validator("type", mode="before") + @classmethod + def _validate_type(cls, v: Any) -> str: + if v is None: + raise ValueError("type is required") + v_str = str(v) + if v_str not in get_args(ElementType): + raise ValueError(f"Invalid type: {v_str}") + return v_str + + @field_validator("display", mode="before") + @classmethod + def _validate_display(cls, v: Any) -> str: + if v is None: + raise ValueError("display is required") + v_str = str(v) + if v_str not in get_args(ElementDisplay): + raise ValueError(f"Invalid display: {v_str}") + return v_str + + @field_validator("size", mode="before") + @classmethod + def _validate_size(cls, v: Any) -> Optional[str]: + if v is None or v == "None": + return None + v_str = str(v) + if v_str not in get_args(ElementSize): + raise ValueError(f"Invalid size: {v_str}") + return v_str + + @classmethod + def from_base(cls, base: ElementBase, for_id: Optional[str] = None) -> "Element": + return cls( + type=str(base.type), + name=base.name, + url=base.url, + path=base.path, + object_key=base.object_key, + chainlit_key=base.chainlit_key, + display=str(base.display), + size=str(base.size) if base.size is not None else None, + language=base.language, + mime=base.mime, + page=base.page, + props=base.props or {}, + auto_play=base.auto_play, + player_config=base.player_config or {}, + for_id=for_id, + ) + + # Validators to enforce allowed values on the DB model + @classmethod + def _allowed(cls, lit) -> List[str]: + return list(get_args(lit)) + + @classmethod + def _validate_choice(cls, value: Optional[str], lit) -> Optional[str]: + if value is None: + return value + allowed = cls._allowed(lit) + if value not in allowed: + raise ValueError(f"Invalid value: {value}. Must be one of: {allowed}") + return value # Simple mapping for type discrimination (Pyplot shares "image", so not included) TYPE_MAP: Dict[str, Any] = { From 3e395634a16d45b773b673f41adb094642889172 Mon Sep 17 00:00:00 2001 From: hayescode <35790761+hayescode@users.noreply.github.com> Date: Wed, 3 Sep 2025 08:19:19 -0500 Subject: [PATCH 9/9] added base for each model --- backend/chainlit/models/__init__.py | 5 - backend/chainlit/models/feedback.py | 37 ++-- backend/chainlit/models/step.py | 258 ++++++++++++++++++++-------- backend/chainlit/models/thread.py | 52 +++--- backend/chainlit/models/user.py | 13 +- 5 files changed, 248 insertions(+), 117 deletions(-) diff --git a/backend/chainlit/models/__init__.py b/backend/chainlit/models/__init__.py index e2b017df11..e69de29bb2 100644 --- a/backend/chainlit/models/__init__.py +++ b/backend/chainlit/models/__init__.py @@ -1,5 +0,0 @@ -from .step import Step as Step -from .user import User, PersistedUser -from .thread import Thread -from .feedback import Feedback, UpdateFeedbackRequest, DeleteFeedbackRequest -from .element import Element, Image, Text, Audio, Video, File, Pyplot, Plotly, CustomElement, Pdf, TaskList, Dataframe \ No newline at end of file diff --git a/backend/chainlit/models/feedback.py b/backend/chainlit/models/feedback.py index 13290f1d75..229ed8d120 100644 --- a/backend/chainlit/models/feedback.py +++ b/backend/chainlit/models/feedback.py @@ -1,18 +1,16 @@ -from typing import Dict, Optional, Literal, get_args +from typing import Optional, Literal from sqlmodel import SQLModel, Field -from pydantic import BaseModel, field_validator, ConfigDict, conint +from pydantic import BaseModel, field_validator, ConfigDict from pydantic.alias_generators import to_camel import uuid from sqlalchemy import Column, ForeignKey, String FeedbackStrategy = Literal["BINARY"] -class Feedback(SQLModel, table=True): - __tablename__ = "feedbacks" - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) - for_id: str = Field(sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"))) - value: int = Field(..., ge=0, le=1) + +class FeedbackBase(SQLModel): + for_id: str + value: int thread_id: Optional[str] = None comment: Optional[str] = None @@ -20,9 +18,10 @@ class Feedback(SQLModel, table=True): alias_generator=to_camel, populate_by_name=True, ) - + @field_validator("value", mode="before") - def validate_type(cls, v): + @classmethod + def validate_value(cls, v): allowed = [0, 1] if v not in allowed: raise ValueError(f"Invalid value: {v}. Must be one of: {allowed}") @@ -30,12 +29,24 @@ def validate_type(cls, v): def to_dict(self): data = self.model_dump(by_alias=True) - data.pop("threadId", None) return data + +class Feedback(FeedbackBase, table=True): + __tablename__ = "feedbacks" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + for_id: str = Field( + sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE")) + ) + value: int = Field(..., ge=0, le=1) + comment: Optional[str] = None + + class UpdateFeedbackRequest(BaseModel): feedback: Feedback session_id: str - + + class DeleteFeedbackRequest(BaseModel): - feedbackId: str + feedbackId: str diff --git a/backend/chainlit/models/step.py b/backend/chainlit/models/step.py index ee0079f877..5b24e335d5 100644 --- a/backend/chainlit/models/step.py +++ b/backend/chainlit/models/step.py @@ -1,28 +1,34 @@ +import asyncio +import inspect +import json +import uuid import asyncio import inspect import json import uuid from copy import deepcopy from functools import wraps -from typing import Callable, Dict, List, Optional, TypedDict, Union, Literal, Any, get_args +from typing import Callable, Dict, List, Optional, Union, Literal, Any, get_args from sqlmodel import SQLModel, Field from sqlalchemy import Column, JSON, ForeignKey, String -from sqlalchemy.dialects.postgresql import JSONB from pydantic import PrivateAttr from pydantic import field_validator -from literalai import BaseGeneration from pydantic import ConfigDict from pydantic.alias_generators import to_camel + from chainlit.config import config from chainlit.context import CL_RUN_NAMES, context, local_steps from chainlit.data import get_data_layer -from chainlit.element import Element from chainlit.logger import logger -from chainlit.types import FeedbackDict from chainlit.utils import utc_now -from chainlit.context import context + +# Import the Element runtime class via models init to avoid circular import +try: + from chainlit.models import Element # type: ignore +except Exception: # pragma: no cover - optional during partial refactors + Element = Any # fallback for type hints TrueStepType = Literal[ "run", "tool", "llm", "embedding", "retrieval", "rerank", "undefined" @@ -32,54 +38,85 @@ StepType = Union[TrueStepType, MessageStepType] -class Step(SQLModel, table=True): - __tablename__ = "steps" - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) - name: str = Field(..., nullable=False) - type: str = Field(..., nullable=False) - thread_id: str = Field(..., sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=False)) - parent_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True)) - disable_feedback: bool = Field(default=False, nullable=False) - streaming: bool = Field(default=False, nullable=False) - wait_for_answer: Optional[bool] = Field(default=None) - is_error: Optional[bool] = Field(default=None) - metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) - input: Optional[str] = Field(default=None) - output: Optional[str] = Field(default=None) - created_at: Optional[str] = Field(default=None) - start: Optional[str] = Field(default=None) - end: Optional[str] = Field(default=None) - generation: Optional[dict] = Field(default_factory=dict, sa_column=Column('generation', JSON), alias='generation') - show_input: str = Field(default="json") - language: Optional[str] = Field(default=None) - indent: Optional[int] = Field(default=None) - tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + +class StepBase(SQLModel): + """Runtime Step model. DB fields overridden in Step(table=True).""" + + # Core fields (runtime view). The DB model will override types as str with validators. + name: str = Field(default="") + type: StepType = Field(default="undefined") + + # Optional linkage; DB model defines FKs + thread_id: Optional[str] = None + parent_id: Optional[str] = None + + # Rendering/behavior + disable_feedback: bool = Field(default=False) + streaming: bool = Field(default=False) + wait_for_answer: Optional[bool] = None + is_error: Optional[bool] = None + + # Payload and metadata + input: Optional[str] = None + output: Optional[str] = None + created_at: Optional[str] = None + start: Optional[str] = None + end: Optional[str] = None + generation: Optional[dict] = None + show_input: Union[bool, str] = Field(default="json") + language: Optional[str] = None + indent: Optional[int] = None + tags: Optional[List[str]] = None default_open: Optional[bool] = Field(default=False) - - model_config = ConfigDict( - alias_generator=to_camel, - populate_by_name=True, + metadata_: Optional[dict] = Field( + default_factory=dict, + alias="metadata", + sa_column=Column("metadata", JSON), + schema_extra={"serialization_alias": "metadata"}, ) - - # TODO define relationship with Element - # elements: List[Element] = Relationship(back_populates="step") - # thread: Optional[Thread] = Relationship(back_populates="steps") model_config = ConfigDict( alias_generator=to_camel, populate_by_name=True, ) - # Private attributes for business logic (not persisted or serialized) - _elements: Optional[List[Element]] = PrivateAttr(default_factory=list) + # Private attributes for business logic (not persisted) + _elements: List[Any] = PrivateAttr(default_factory=list) _fail_on_persist_error: bool = PrivateAttr(default=False) _input: str = PrivateAttr(default="") _output: str = PrivateAttr(default="") + _persisted: bool = PrivateAttr(default=False) + + # Convenience properties + @property + def persisted(self) -> bool: + return self._persisted + + @persisted.setter + def persisted(self, v: bool) -> None: + self._persisted = bool(v) + + @property + def elements(self) -> List[Any]: + return self._elements + + @property + def fail_on_persist_error(self) -> bool: + return self._fail_on_persist_error + + @fail_on_persist_error.setter + def fail_on_persist_error(self, v: bool) -> None: + self._fail_on_persist_error = bool(v) @field_validator("type", mode="before") - def validate_type(cls, v): - allowed = [v for arg in get_args(StepType) for v in (get_args(arg) if hasattr(arg, "__args__") else [arg])] + @classmethod + def _validate_type(cls, v: Any) -> Any: + # Accept literals on base; DB class enforces strict string values + allowed = [ + value + for arg in get_args(StepType) + for value in (get_args(arg) if hasattr(arg, "__args__") else [arg]) + ] if v not in allowed: raise ValueError(f"Invalid type: {v}. Must be one of: {allowed}") return v @@ -113,6 +150,7 @@ def handle_bytes(item): elif isinstance(item, tuple): return tuple(handle_bytes(i) for i in item) return item + return handle_bytes(content) def _process_content(self, content, set_language=False): @@ -171,12 +209,11 @@ def __enter__(self): previous_steps = local_steps.get() or [] parent_step = previous_steps[-1] if previous_steps else None - if not self.parent_id: - if parent_step: - self.parent_id = parent_step.id + if not self.parent_id and parent_step: + self.parent_id = parent_step.id local_steps.set(previous_steps + [self]) - task = asyncio.create_task(self.send()) + asyncio.create_task(self.send()) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -191,7 +228,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): current_steps.remove(self) local_steps.set(current_steps) - task = asyncio.create_task(self.update()) + asyncio.create_task(self.update()) async def update(self): if self.streaming: @@ -202,7 +239,7 @@ async def update(self): if data_layer: try: - task = asyncio.create_task(data_layer.update_step(step_dict.copy())) + asyncio.create_task(data_layer.update_step(step_dict.copy())) except Exception as e: if self.fail_on_persist_error: raise e @@ -225,7 +262,7 @@ async def remove(self): if data_layer: try: - task = asyncio.create_task(data_layer.delete_step(self.id)) + asyncio.create_task(data_layer.delete_step(self.id)) except Exception as e: if self.fail_on_persist_error: raise e @@ -249,7 +286,7 @@ async def send(self): if data_layer: try: - task = asyncio.create_task(data_layer.create_step(step_dict.copy())) + asyncio.create_task(data_layer.create_step(step_dict.copy())) self.persisted = True except Exception as e: if self.fail_on_persist_error: @@ -259,6 +296,7 @@ async def send(self): tasks = [el.send(for_id=self.id) for el in getattr(self, 'elements', [])] await asyncio.gather(*tasks) + from chainlit.context import check_add_step_in_cot if not check_add_step_in_cot(self): await context.emitter.send_step(self.to_dict()) else: @@ -270,6 +308,8 @@ async def stream_token(self, token: str, is_sequence=False, is_input=False): if not token: return + from chainlit.context import check_add_step_in_cot, stub_step + if is_sequence: if is_input: self.input_value = token @@ -296,13 +336,74 @@ async def stream_token(self, token: str, is_sequence=False, is_input=False): id=self.id, token=token, is_sequence=is_sequence, is_input=is_input ) + +class Step(StepBase, table=True): + __tablename__ = "steps" + + # DB identity and relations + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + thread_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=True), + ) + parent_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True), + ) + + # Override Literal and complex fields with DB-compatible types/columns + type: str = Field(..., nullable=False) + tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + metadata_: Optional[dict] = Field( + default_factory=dict, + sa_column=Column("metadata", JSON), + alias="metadata", + schema_extra={"serialization_alias": "metadata"}, + ) + generation: Optional[dict] = Field( + default_factory=dict, + sa_column=Column("generation", JSON), + alias="generation", + ) + show_input: str + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + @field_validator("type", mode="before") + @classmethod + def _validate_type_db(cls, v: Any) -> str: + if v is None: + raise ValueError("type is required") + v_str = str(v) + allowed = [ + value + for arg in get_args(StepType) + for value in (get_args(arg) if hasattr(arg, "__args__") else [arg]) + ] + if v_str not in allowed: + raise ValueError(f"Invalid type: {v}. Must be one of: {allowed}") + return v_str + + @classmethod + def from_base(cls, base: "StepBase") -> "Step": + data = base.model_dump(by_alias=True) + # Map runtime metadata -> metadata_ + if "metadata" in data and data.get("metadata") is not None: + data["metadata_"] = data.pop("metadata") + return cls.model_validate(data) + + def flatten_args_kwargs(func, args, kwargs): signature = inspect.signature(func) bound_arguments = signature.bind(*args, **kwargs) bound_arguments.apply_defaults() return {k: deepcopy(v) for k, v in bound_arguments.arguments.items()} -def check_add_step_in_cot(step: Step): + +def check_add_step_in_cot(step: StepBase): is_message = step.type in [ "user_message", "assistant_message", @@ -312,25 +413,28 @@ def check_add_step_in_cot(step: Step): return False return True -# Step decorator for async and sync functions, now using StepService + def step( - original_function: Optional[Callable] = None, - *, - name: Optional[str] = "", - type: Optional[str] = "undefined", - id: Optional[str] = None, - parent_id: Optional[str] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - language: Optional[str] = None, - show_input: Union[bool, str] = "json", - default_open: bool = False - ) -> Callable: + original_function: Optional[Callable] = None, + *, + name: Optional[str] = "", + type: Optional[str] = "undefined", + id: Optional[str] = None, + parent_id: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict] = None, + language: Optional[str] = None, + show_input: Union[bool, str] = "json", + default_open: bool = False, +) -> Callable: + """Decorator to wrap functions in a Step context.""" + def wrapper(func: Callable): nonlocal name if not name: name = func.__name__ if inspect.iscoroutinefunction(func): + @wraps(func) async def async_wrapper(*args, **kwargs): async with Step( @@ -343,21 +447,23 @@ async def async_wrapper(*args, **kwargs): show_input=show_input, default_open=default_open, metadata=metadata, - ) as step: + ) as step_obj: try: - step.input = flatten_args_kwargs(func, args, kwargs) + step_obj.input = flatten_args_kwargs(func, args, kwargs) except Exception: pass result = await func(*args, **kwargs) try: - if result and not step.output: - step.output = result + if result and not step_obj.output: + step_obj.output = result except Exception: - step.is_error = True - step.output = str(result) + step_obj.is_error = True + step_obj.output = str(result) return result + return async_wrapper else: + @wraps(func) def sync_wrapper(*args, **kwargs): with Step( @@ -370,20 +476,22 @@ def sync_wrapper(*args, **kwargs): show_input=show_input, default_open=default_open, metadata=metadata, - ) as step: + ) as step_obj: try: - step.input = flatten_args_kwargs(func, args, kwargs) + step_obj.input = flatten_args_kwargs(func, args, kwargs) except Exception: pass result = func(*args, **kwargs) try: - if result and not step.output: - step.output = result + if result and not step_obj.output: + step_obj.output = result except Exception: - step.is_error = True - step.output = str(result) + step_obj.is_error = True + step_obj.output = str(result) return result + return sync_wrapper + func = original_function if not func: return wrapper diff --git a/backend/chainlit/models/thread.py b/backend/chainlit/models/thread.py index 04741fe39e..860be718c6 100644 --- a/backend/chainlit/models/thread.py +++ b/backend/chainlit/models/thread.py @@ -1,5 +1,5 @@ -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Protocol, TypeVar, Union, Self +from typing import Dict, Generic, List, Optional, TypeVar, Self from sqlmodel import SQLModel, Field from pydantic import PrivateAttr, BaseModel import uuid @@ -7,34 +7,34 @@ from pydantic.alias_generators import to_camel from sqlalchemy import Column, JSON, ForeignKey, String -if TYPE_CHECKING: - from chainlit.element import ElementDict - from chainlit.step import StepDict -# Unified thread model -class Thread(SQLModel, table=True): - __tablename__ = "threads" - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) - created_at: str = "" +class ThreadBase(SQLModel): + created_at: Optional[str] = None name: Optional[str] = None - user_id: Optional[str] = Field(default=None, sa_column=Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=True)) + user_id: Optional[str] = None user_identifier: Optional[str] = None - tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) - metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) - + tags: Optional[List[str]] = None + # Persisted as JSON column named "metadata", but exposed as `metadata` in the API + metadata_: Optional[dict] = Field( + default_factory=dict, + alias="metadata", + sa_column=Column("metadata", JSON), + schema_extra={"serialization_alias": "metadata"}, + ) + model_config = ConfigDict( alias_generator=to_camel, populate_by_name=True, ) - # Private attributes for business logic (not persisted or serialized) - _steps: Optional[List] = None # You can specify List[Step] if imported - _elements: Optional[List] = None # You can specify List[Element] if imported + # Private runtime attributes + _steps: Optional[List] = None + _elements: Optional[List] = None _runtime_state: dict = PrivateAttr(default_factory=dict) - # Example business logic method def add_tag(self, tag: str): + if self.tags is None: + self.tags = [] if tag not in self.tags: self.tags.append(tag) @@ -46,6 +46,17 @@ def from_dict(cls, **kwargs) -> Self: return cls.model_validate(**kwargs) +class Thread(ThreadBase, table=True): + __tablename__ = "threads" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + user_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=True), + ) + tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + + # Pagination and ThreadFilter class Pagination(BaseModel): first: int @@ -82,8 +93,9 @@ def from_dict( cls, paginated_response_dict: Dict ) -> "PaginatedResponse[T]": page_info = PageInfo.from_dict(paginated_response_dict.get("page_info", {})) - data = [the_class.from_dict(d) for d in paginated_response_dict.get("data", [])] - return cls(page_info=page_info, data=data) + # Without runtime type info for T, return data as-is + data_list = paginated_response_dict.get("data", []) + return cls(page_info=page_info, data=data_list) # Thread requests/responses class UpdateThreadRequest(BaseModel): diff --git a/backend/chainlit/models/user.py b/backend/chainlit/models/user.py index 95d4ca3085..03549e179e 100644 --- a/backend/chainlit/models/user.py +++ b/backend/chainlit/models/user.py @@ -25,15 +25,20 @@ class User(BaseModel): display_name: Optional[str] = None metadata: Dict = Field(default_factory=dict) -# Persisted user (for database use) + class PersistedUser(SQLModel, table=True): __tablename__ = "users" - + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) identifier: str display_name: Optional[str] = None - metadata_: Optional[dict] = Field(default_factory=dict, sa_column=Column('metadata', JSON), alias='metadata', schema_extra={'serialization_alias': 'metadata'}) - created_at: str = Field(default_factory=utc_now(), primary_key=True) + metadata_: Optional[dict] = Field( + default_factory=dict, + sa_column=Column("metadata", JSON), + alias="metadata", + schema_extra={"serialization_alias": "metadata"}, + ) + created_at: str = Field(default_factory=utc_now) model_config = ConfigDict( alias_generator=to_camel,