From 840894e51a92ca98a8f8e8f449d41986c694ed62 Mon Sep 17 00:00:00 2001 From: ganyang Date: Thu, 18 Sep 2025 13:18:19 +0800 Subject: [PATCH 1/8] feat(vector_stores): Adds support for MariaDB vector stores - Adds the MariaDBConfig configuration class and MariaDB vector store implementation - Updates factory.py and configs.py to support MariaDB - Adds MariaDB-related test cases - Adds the mysql-connector-python dependency to pyproject.toml --- mem0/configs/vector_stores/mariadb.py | 64 ++++ mem0/utils/factory.py | 1 + mem0/vector_stores/configs.py | 1 + mem0/vector_stores/mariadb.py | 471 ++++++++++++++++++++++++++ pyproject.toml | 1 + tests/vector_stores/test_mariadb.py | 243 +++++++++++++ 6 files changed, 781 insertions(+) create mode 100644 mem0/configs/vector_stores/mariadb.py create mode 100644 mem0/vector_stores/mariadb.py create mode 100644 tests/vector_stores/test_mariadb.py diff --git a/mem0/configs/vector_stores/mariadb.py b/mem0/configs/vector_stores/mariadb.py new file mode 100644 index 0000000000..e8d8cc09ad --- /dev/null +++ b/mem0/configs/vector_stores/mariadb.py @@ -0,0 +1,64 @@ +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field, model_validator + + +class MariaDBConfig(BaseModel): + dbname: str = Field("mem0", description="Default name for the database") + collection_name: str = Field("mem0", description="Default name for the collection") + embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") + user: Optional[str] = Field(None, description="Database user") + password: Optional[str] = Field(None, description="Database password") + host: Optional[str] = Field(None, description="Database host. Default is localhost") + port: Optional[int] = Field(None, description="Database port. 
Default is 3306") + distance_function: Optional[str] = Field("euclidean", description="Distance function for vector index ('euclidean' or 'cosine')") + m_value: Optional[int] = Field(16, description="M parameter for HNSW index (3-200). Higher values = more accurate but slower") + # SSL and connection options + ssl_disabled: Optional[bool] = Field(False, description="Disable SSL connection") + ssl_ca: Optional[str] = Field(None, description="SSL CA certificate file path") + ssl_cert: Optional[str] = Field(None, description="SSL certificate file path") + ssl_key: Optional[str] = Field(None, description="SSL key file path") + connection_string: Optional[str] = Field(None, description="MariaDB connection string (overrides individual connection parameters)") + charset: Optional[str] = Field("utf8mb4", description="Character set for the connection") + autocommit: Optional[bool] = Field(True, description="Enable autocommit mode") + + @model_validator(mode="before") + def check_auth_and_connection(cls, values): + # If connection_string is provided, skip validation of individual connection parameters + if values.get("connection_string") is not None: + return values + + # Otherwise, validate individual connection parameters + user, password = values.get("user"), values.get("password") + host, port = values.get("host"), values.get("port") + if not user and not password: + raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.") + if not host and not port: + raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.") + return values + + @model_validator(mode="before") + def validate_distance_function(cls, values): + distance_function = values.get("distance_function", "euclidean") + if distance_function not in ["euclidean", "cosine"]: + raise ValueError("distance_function must be either 'euclidean' or 'cosine'") + return values + + @model_validator(mode="before") + def validate_m_value(cls, values): + 
m_value = values.get("m_value", 16) + if not (3 <= m_value <= 200): + raise ValueError("m_value must be between 3 and 200") + return values + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + ) + return values diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 3ac56c8249..2addda3e6e 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -177,6 +177,7 @@ class VectorStoreFactory: "s3_vectors": "mem0.vector_stores.s3_vectors.S3Vectors", "baidu": "mem0.vector_stores.baidu.BaiduDB", "neptune": "mem0.vector_stores.neptune_analytics.NeptuneAnalyticsVector", + "mariadb": "mem0.vector_stores.mariadb.MariaDB", } @classmethod diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index ff9fd995fc..eebf512b52 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -32,6 +32,7 @@ class VectorStoreConfig(BaseModel): "faiss": "FAISSConfig", "langchain": "LangchainConfig", "s3_vectors": "S3VectorsConfig", + "mariadb": "MariaDBConfig", } @model_validator(mode="after") diff --git a/mem0/vector_stores/mariadb.py b/mem0/vector_stores/mariadb.py new file mode 100644 index 0000000000..d623b87fa2 --- /dev/null +++ b/mem0/vector_stores/mariadb.py @@ -0,0 +1,471 @@ +import json +import logging +from contextlib import contextmanager +from typing import Any, List, Optional + +from pydantic import BaseModel + +try: + import mysql.connector +except ImportError as e: + raise ImportError( + "mysql.connector is not available. 
Please install it using 'pip install mysql-connector-python'" + ) from e + +from mem0.vector_stores.base import VectorStoreBase + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: Optional[str] + score: Optional[float] + payload: Optional[dict] + + +class MariaDB(VectorStoreBase): + def __init__( + self, + dbname, + collection_name, + embedding_model_dims, + user, + password, + host, + port, + distance_function="euclidean", + m_value=16, + ssl_disabled=False, + ssl_ca=None, + ssl_cert=None, + ssl_key=None, + connection_string=None, + charset="utf8mb4", + autocommit=True, + ): + """ + Initialize the MariaDB Vector database. + + Args: + dbname (str): Database name + collection_name (str): Collection name + embedding_model_dims (int): Dimension of the embedding vector + user (str): Database user + password (str): Database password + host (str): Database host + port (int): Database port + distance_function (str): Distance function for vector index ('euclidean' or 'cosine') + m_value (int): M parameter for HNSW index (3-200). 
Higher values = more accurate but slower + ssl_disabled (bool): Disable SSL connection + ssl_ca (str, optional): SSL CA certificate file path + ssl_cert (str, optional): SSL certificate file path + ssl_key (str, optional): SSL key file path + connection_string (str, optional): MariaDB connection string (overrides individual connection parameters) + charset (str): Character set for the connection + autocommit (bool): Enable autocommit mode + """ + self.collection_name = collection_name + self.embedding_model_dims = embedding_model_dims + self.distance_function = distance_function + self.m_value = m_value + + # Connection parameters + if connection_string: + # Parse connection string (simplified parsing) + # Format: mariadb://user:password@host:port/database + import urllib.parse + parsed = urllib.parse.urlparse(connection_string) + self.connection_params = { + 'user': parsed.username, + 'password': parsed.password, + 'host': parsed.hostname, + 'port': parsed.port or 3306, + 'database': parsed.path.lstrip('/') or dbname, + 'charset': charset, + 'autocommit': autocommit, + } + else: + self.connection_params = { + 'user': user, + 'password': password, + 'host': host, + 'port': port or 3306, + 'database': dbname, + 'charset': charset, + 'autocommit': autocommit, + } + + # SSL configuration + if not ssl_disabled: + ssl_config = {} + if ssl_ca: + ssl_config['ca'] = ssl_ca + if ssl_cert: + ssl_config['cert'] = ssl_cert + if ssl_key: + ssl_config['key'] = ssl_key + if ssl_config: + self.connection_params['ssl'] = ssl_config + + # Test connection and create collection if needed + collections = self.list_cols() + if collection_name not in collections: + self.create_col() + + @contextmanager + def _get_connection(self): + """ + Context manager to get a database connection. 
+ """ + conn = None + try: + conn = mysql.connector.connect(**self.connection_params) + yield conn + except Exception as e: + logger.error(f"Database connection error: {e}") + if conn: + conn.rollback() + raise + finally: + if conn: + conn.close() + + + def create_col(self) -> None: + """ + Create a new collection (table in MariaDB). + Will also initialize vector search index. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + try: + # Create table with VECTOR column + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS {self.collection_name} ( + id VARCHAR(255) PRIMARY KEY, + embedding VECTOR({self.embedding_model_dims}) NOT NULL, + payload JSON, + VECTOR INDEX (embedding) M={self.m_value} DISTANCE={self.distance_function} + ) + """) + conn.commit() + logger.info(f"Created collection {self.collection_name} with vector index") + except Exception as e: + logger.error(f"Error creating collection: {e}") + raise + finally: + cursor.close() + + def insert(self, vectors: List[List[float]], payloads=None, ids=None) -> None: + """ + Insert vectors into the collection. 
+ + Args: + vectors: List of vectors to insert + payloads: List of payload dictionaries + ids: List of IDs for the vectors + """ + logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}") + + if payloads is None: + payloads = [{}] * len(vectors) + if ids is None: + import uuid + ids = [str(uuid.uuid4()) for _ in vectors] + + with self._get_connection() as conn: + cursor = conn.cursor() + try: + # Insert vectors one by one using VEC_FromText function + for vector_id, vector, payload in zip(ids, vectors, payloads): + # Convert vector to string format for VEC_FromText + vector_str = '[' + ','.join(map(str, vector)) + ']' + payload_json = json.dumps(payload) if payload else None + + cursor.execute(f""" + INSERT INTO {self.collection_name} (id, embedding, payload) + VALUES (%s, VEC_FromText(%s), %s) + """, (vector_id, vector_str, payload_json)) + + conn.commit() + except Exception as e: + logger.error(f"Error inserting vectors: {e}") + conn.rollback() + raise + finally: + cursor.close() + + def search( + self, + query: str, + vectors: List[float], + limit: Optional[int] = 5, + filters: Optional[dict] = None, + ) -> List[OutputData]: + """ + Search for similar vectors using MariaDB Vector distance functions. + + Args: + query (str): Query string (for logging) + vectors (List[float]): Query vector + limit (int, optional): Number of results to return. Defaults to 5. + filters (Dict, optional): Filters to apply to the search. Defaults to None. + + Returns: + List[OutputData]: Search results. 
+ """ + filter_conditions = [] + filter_params = [] + + if filters: + for k, v in filters.items(): + filter_conditions.append("JSON_EXTRACT(payload, %s) = %s") + filter_params.extend([f"$.{k}", str(v)]) + + filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" + + # Convert query vector to string format for VEC_FromText + query_vector_str = '[' + ','.join(map(str, vectors)) + ']' + + with self._get_connection() as conn: + cursor = conn.cursor() + try: + # Use VEC_DISTANCE function which automatically uses the appropriate distance function + if filter_conditions: + query_params = [query_vector_str] + filter_params + [limit] + else: + query_params = [query_vector_str, limit] + + logger.debug(f"SQL query: SELECT id, VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(%s)) AS distance, payload FROM {self.collection_name} {filter_clause} ORDER BY distance LIMIT %s") + logger.debug(f"Query params: {query_params}") + + cursor.execute(f""" + SELECT id, VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(%s)) AS distance, payload + FROM {self.collection_name} + {filter_clause} + ORDER BY distance + LIMIT %s + """, query_params) + + results = cursor.fetchall() + return [ + OutputData( + id=str(r[0]), + score=float(r[1]), + payload=json.loads(r[2]) if r[2] else {} + ) + for r in results + ] + except Exception as e: + logger.error(f"Error searching vectors: {e}") + raise + finally: + cursor.close() + + def delete(self, vector_id: str) -> None: + """ + Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. 
+ """ + with self._get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,)) + conn.commit() + except Exception as e: + logger.error(f"Error deleting vector: {e}") + raise + finally: + cursor.close() + + def update( + self, + vector_id: str, + vector: Optional[List[float]] = None, + payload: Optional[dict] = None, + ) -> None: + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (List[float], optional): Updated vector. + payload (Dict, optional): Updated payload. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + try: + if vector: + vector_str = '[' + ','.join(map(str, vector)) + ']' + cursor.execute( + f"UPDATE {self.collection_name} SET embedding = VEC_FromText(%s) WHERE id = %s", + (vector_str, vector_id), + ) + if payload: + cursor.execute( + f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s", + (json.dumps(payload), vector_id), + ) + conn.commit() + except Exception as e: + logger.error(f"Error updating vector: {e}") + conn.rollback() + raise + finally: + cursor.close() + + def get(self, vector_id: str) -> Optional[OutputData]: + """ + Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + OutputData: Retrieved vector data or None if not found. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute( + f"SELECT id, embedding, payload FROM {self.collection_name} WHERE id = %s", + (vector_id,), + ) + result = cursor.fetchone() + if not result: + return None + + payload = json.loads(result[2]) if result[2] else {} + return OutputData(id=str(result[0]), score=None, payload=payload) + except Exception as e: + logger.error(f"Error retrieving vector: {e}") + raise + finally: + cursor.close() + + def list_cols(self) -> List[str]: + """ + List all collections (tables). + + Returns: + List[str]: List of collection names. 
+ """ + with self._get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SHOW TABLES") + return [row[0] for row in cursor.fetchall()] + except Exception as e: + logger.error(f"Error listing collections: {e}") + raise + finally: + cursor.close() + + def delete_col(self) -> None: + """Delete the collection (table).""" + with self._get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(f"DROP TABLE IF EXISTS {self.collection_name}") + conn.commit() + logger.info(f"Deleted collection {self.collection_name}") + except Exception as e: + logger.error(f"Error deleting collection: {e}") + raise + finally: + cursor.close() + + def col_info(self) -> dict[str, Any]: + """ + Get information about the collection. + + Returns: + Dict[str, Any]: Collection information. + """ + with self._get_connection() as conn: + cursor = conn.cursor() + try: + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {self.collection_name}") + row_count = cursor.fetchone()[0] + + # Get table size information + cursor.execute(f""" + SELECT + table_name, + ROUND(((data_length + index_length) / 1024 / 1024), 2) AS 'size_mb' + FROM information_schema.tables + WHERE table_schema = DATABASE() AND table_name = %s + """, (self.collection_name,)) + + result = cursor.fetchone() + size_mb = result[1] if result else 0 + + return { + "name": self.collection_name, + "count": row_count, + "size": f"{size_mb} MB" + } + except Exception as e: + logger.error(f"Error getting collection info: {e}") + raise + finally: + cursor.close() + + def list( + self, + filters: Optional[dict] = None, + limit: Optional[int] = 100 + ) -> List[OutputData]: + """ + List all vectors in the collection. + + Args: + filters (Dict, optional): Filters to apply to the list. + limit (int, optional): Number of vectors to return. Defaults to 100. + + Returns: + List[OutputData]: List of vectors. 
+ """ + filter_conditions = [] + filter_params = [] + + if filters: + for k, v in filters.items(): + filter_conditions.append("JSON_EXTRACT(payload, %s) = %s") + filter_params.extend([f"$.{k}", str(v)]) + + filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" + + with self._get_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(f""" + SELECT id, embedding, payload + FROM {self.collection_name} + {filter_clause} + LIMIT %s + """, (*filter_params, limit)) + + results = cursor.fetchall() + return [ + OutputData( + id=str(r[0]), + score=None, + payload=json.loads(r[2]) if r[2] else {} + ) + for r in results + ] + except Exception as e: + logger.error(f"Error listing vectors: {e}") + raise + finally: + cursor.close() + + def reset(self) -> None: + """Reset the collection by deleting and recreating it.""" + logger.warning(f"Resetting collection {self.collection_name}...") + self.delete_col() + self.create_col() diff --git a/pyproject.toml b/pyproject.toml index 8b9a060c68..5ab285e617 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ vector_stores = [ "elasticsearch>=8.0.0,<9.0.0", "pymilvus>=2.4.0,<2.6.0", "langchain-aws>=0.2.23", + "mysql-connector-python>=9.3.0", ] llms = [ "groq>=0.3.0", diff --git a/tests/vector_stores/test_mariadb.py b/tests/vector_stores/test_mariadb.py new file mode 100644 index 0000000000..9f037ecc0e --- /dev/null +++ b/tests/vector_stores/test_mariadb.py @@ -0,0 +1,243 @@ +import pytest +import struct +from unittest.mock import Mock, patch + +# Create mock modules for mysql and mysql.connector +mock_mysql_connector = Mock() +mock_mysql_connector.connect = Mock() + +mock_mysql = Mock() +mock_mysql.connector = mock_mysql_connector + +# Patch sys.modules before importing +with patch.dict('sys.modules', { + 'mysql': mock_mysql, + 'mysql.connector': mock_mysql_connector +}): + from mem0.vector_stores.mariadb import MariaDB, OutputData + + +class TestMariaDB: + @pytest.fixture + 
def mock_mariadb_connection(self): + """Mock MariaDB connection and cursor""" + mock_conn = Mock() + mock_cursor = Mock() + mock_conn.cursor.return_value = mock_cursor + + with patch.object(mock_mysql_connector, 'connect', return_value=mock_conn): + yield mock_conn, mock_cursor + + @pytest.fixture + def mariadb_store(self, mock_mariadb_connection): + """Create MariaDB vector store instance with mocked connection""" + mock_conn, mock_cursor = mock_mariadb_connection + + # Mock list_cols to return empty list initially + mock_cursor.fetchall.return_value = [] + + store = MariaDB( + dbname="test_db", + collection_name="test_collection", + embedding_model_dims=3, + user="test_user", + password="test_password", + host="localhost", + port=3306, + ) + return store + + def test_init(self, mariadb_store): + """Test MariaDB initialization""" + assert mariadb_store.collection_name == "test_collection" + assert mariadb_store.embedding_model_dims == 3 + assert mariadb_store.distance_function == "euclidean" + assert mariadb_store.m_value == 16 + + def test_init_with_connection_string(self, mock_mariadb_connection): + """Test initialization with connection string""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchall.return_value = [] + + store = MariaDB( + dbname="test_db", + collection_name="test_collection", + embedding_model_dims=3, + user=None, + password=None, + host=None, + port=None, + connection_string="mariadb://user:pass@localhost:3306/testdb" + ) + + assert store.connection_params['user'] == 'user' + assert store.connection_params['password'] == 'pass' + assert store.connection_params['host'] == 'localhost' + assert store.connection_params['port'] == 3306 + + + def test_create_col(self, mariadb_store, mock_mariadb_connection): + """Test collection creation""" + mock_conn, mock_cursor = mock_mariadb_connection + + mariadb_store.create_col() + + # Verify SQL execution + mock_cursor.execute.assert_called() + mock_conn.commit.assert_called() + + def 
test_insert(self, mariadb_store, mock_mariadb_connection): + """Test vector insertion""" + mock_conn, mock_cursor = mock_mariadb_connection + + vectors = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + payloads = [{"key": "value1"}, {"key": "value2"}] + ids = ["id1", "id2"] + + mariadb_store.insert(vectors, payloads, ids) + + # Verify execute was called for each vector + assert mock_cursor.execute.call_count >= 2 + mock_conn.commit.assert_called() + + def test_search(self, mariadb_store, mock_mariadb_connection): + """Test vector search""" + mock_conn, mock_cursor = mock_mariadb_connection + + # Mock search results + mock_cursor.fetchall.return_value = [ + ("id1", 0.1, '{"key": "value1"}'), + ("id2", 0.2, '{"key": "value2"}') + ] + + query_vector = [1.0, 2.0, 3.0] + results = mariadb_store.search("test query", query_vector, limit=2) + + assert len(results) == 2 + assert isinstance(results[0], OutputData) + assert results[0].id == "id1" + assert results[0].score == 0.1 + assert results[0].payload == {"key": "value1"} + + def test_search_with_filters(self, mariadb_store, mock_mariadb_connection): + """Test vector search with filters""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchall.return_value = [("id1", 0.1, '{"key": "value1"}')] + + query_vector = [1.0, 2.0, 3.0] + filters = {"category": "test"} + + results = mariadb_store.search("test query", query_vector, limit=1, filters=filters) + + # Verify filter was applied in SQL + mock_cursor.execute.assert_called() + call_args = mock_cursor.execute.call_args + assert "JSON_EXTRACT" in call_args[0][0] + + def test_delete(self, mariadb_store, mock_mariadb_connection): + """Test vector deletion""" + mock_conn, mock_cursor = mock_mariadb_connection + + mariadb_store.delete("test_id") + + mock_cursor.execute.assert_called() + mock_conn.commit.assert_called() + + def test_update(self, mariadb_store, mock_mariadb_connection): + """Test vector update""" + mock_conn, mock_cursor = mock_mariadb_connection + + 
new_vector = [7.0, 8.0, 9.0] + new_payload = {"updated": "data"} + + mariadb_store.update("test_id", vector=new_vector, payload=new_payload) + + # Should be called twice - once for vector, once for payload + assert mock_cursor.execute.call_count >= 2 + mock_conn.commit.assert_called() + + def test_get(self, mariadb_store, mock_mariadb_connection): + """Test vector retrieval""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchone.return_value = ("test_id", b"vector_data", '{"key": "value"}') + + result = mariadb_store.get("test_id") + + assert isinstance(result, OutputData) + assert result.id == "test_id" + assert result.payload == {"key": "value"} + + def test_get_not_found(self, mariadb_store, mock_mariadb_connection): + """Test vector retrieval when not found""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchone.return_value = None + + result = mariadb_store.get("nonexistent_id") + + assert result is None + + def test_list_cols(self, mariadb_store, mock_mariadb_connection): + """Test listing collections""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchall.return_value = [("table1",), ("table2",)] + + collections = mariadb_store.list_cols() + + assert collections == ["table1", "table2"] + + def test_delete_col(self, mariadb_store, mock_mariadb_connection): + """Test collection deletion""" + mock_conn, mock_cursor = mock_mariadb_connection + + mariadb_store.delete_col() + + mock_cursor.execute.assert_called() + mock_conn.commit.assert_called() + + def test_col_info(self, mariadb_store, mock_mariadb_connection): + """Test collection info retrieval""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchone.side_effect = [ + (100,), # row count + ("test_collection", 5.5) # table info + ] + + info = mariadb_store.col_info() + + assert info["name"] == "test_collection" + assert info["count"] == 100 + assert "MB" in info["size"] + + def test_list(self, mariadb_store, 
mock_mariadb_connection): + """Test listing vectors""" + mock_conn, mock_cursor = mock_mariadb_connection + mock_cursor.fetchall.return_value = [ + ("id1", b"vector1", '{"key": "value1"}'), + ("id2", b"vector2", '{"key": "value2"}') + ] + + results = mariadb_store.list(limit=10) + + assert len(results) == 2 + assert all(isinstance(r, OutputData) for r in results) + + def test_reset(self, mariadb_store, mock_mariadb_connection): + """Test collection reset""" + mock_conn, mock_cursor = mock_mariadb_connection + + mariadb_store.reset() + + # Should call delete_col and create_col + assert mock_cursor.execute.call_count >= 2 + + def test_import_error(self): + """Test behavior when mysql.connector is not available""" + # Test that ImportError is raised during module import when mysql.connector is not available + with patch.dict('sys.modules', {'mysql.connector': None}): + with pytest.raises(ImportError, match="mysql.connector is not available"): + # Force reimport of the module to trigger the ImportError + import importlib + import sys + if 'mem0.vector_stores.mariadb' in sys.modules: + del sys.modules['mem0.vector_stores.mariadb'] + importlib.import_module('mem0.vector_stores.mariadb') From eaef765abd471482609b8f6e3cd8a7455a843ee9 Mon Sep 17 00:00:00 2001 From: ganyang Date: Thu, 18 Sep 2025 13:36:06 +0800 Subject: [PATCH 2/8] feat(vector_stores): Adds support for MariaDB vector stores - Adds the MariaDBConfig configuration class and MariaDB vector store implementation - Updates factory.py and configs.py to support MariaDB - Adds MariaDB-related test cases - Adds the mysql-connector-python dependency to pyproject.toml --- mem0/vector_stores/mariadb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mem0/vector_stores/mariadb.py b/mem0/vector_stores/mariadb.py index d623b87fa2..fad2973f67 100644 --- a/mem0/vector_stores/mariadb.py +++ b/mem0/vector_stores/mariadb.py @@ -395,7 +395,7 @@ def col_info(self) -> dict[str, Any]: cursor.execute(f""" SELECT 
table_name, - ROUND(((data_length + index_length) / 1024 / 1024), 2) AS 'size_mb' + ROUND(((data_length + index_length) / 1024 / 1024), 2) AS total_size_mb FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = %s """, (self.collection_name,)) From 3b6a5001a89741502f7c00f71e1d69e6981f117d Mon Sep 17 00:00:00 2001 From: ganyang Date: Thu, 18 Sep 2025 15:46:08 +0800 Subject: [PATCH 3/8] feat(vector_stores): Adds support for Aliyun MySQL vector stores - Adds the MySQLVectorConfig configuration class and MySQL vector store implementation - Updates factory.py and configs.py to support Aliyun MySQL - Adds MySQLVector-related test cases - Adds the mysql-connector-python dependency to pyproject.toml --- .../{mariadb.py => aliyun_mysql.py} | 4 +- mem0/utils/factory.py | 2 +- .../{mariadb.py => aliyun_mysql.py} | 12 +- mem0/vector_stores/configs.py | 2 +- .../{test_mariadb.py => test_aliyun_mysql.py} | 120 +++++++++--------- 5 files changed, 70 insertions(+), 70 deletions(-) rename mem0/configs/vector_stores/{mariadb.py => aliyun_mysql.py} (96%) rename mem0/vector_stores/{mariadb.py => aliyun_mysql.py} (97%) rename tests/vector_stores/{test_mariadb.py => test_aliyun_mysql.py} (62%) diff --git a/mem0/configs/vector_stores/mariadb.py b/mem0/configs/vector_stores/aliyun_mysql.py similarity index 96% rename from mem0/configs/vector_stores/mariadb.py rename to mem0/configs/vector_stores/aliyun_mysql.py index e8d8cc09ad..d874b14f9d 100644 --- a/mem0/configs/vector_stores/mariadb.py +++ b/mem0/configs/vector_stores/aliyun_mysql.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, model_validator -class MariaDBConfig(BaseModel): +class MySQLVectorConfig(BaseModel): dbname: str = Field("mem0", description="Default name for the database") collection_name: str = Field("mem0", description="Default name for the collection") embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") @@ -18,7 +18,7 @@ class 
MariaDBConfig(BaseModel): ssl_ca: Optional[str] = Field(None, description="SSL CA certificate file path") ssl_cert: Optional[str] = Field(None, description="SSL certificate file path") ssl_key: Optional[str] = Field(None, description="SSL key file path") - connection_string: Optional[str] = Field(None, description="MariaDB connection string (overrides individual connection parameters)") + connection_string: Optional[str] = Field(None, description="Aliyun MySQL connection string (overrides individual connection parameters)") charset: Optional[str] = Field("utf8mb4", description="Character set for the connection") autocommit: Optional[bool] = Field(True, description="Enable autocommit mode") diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 2addda3e6e..c1313586d8 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -177,7 +177,7 @@ class VectorStoreFactory: "s3_vectors": "mem0.vector_stores.s3_vectors.S3Vectors", "baidu": "mem0.vector_stores.baidu.BaiduDB", "neptune": "mem0.vector_stores.neptune_analytics.NeptuneAnalyticsVector", - "mariadb": "mem0.vector_stores.mariadb.MariaDB", + "aliyun_mysql": "mem0.vector_stores.aliyun_mysql.MySQLVector", } @classmethod diff --git a/mem0/vector_stores/mariadb.py b/mem0/vector_stores/aliyun_mysql.py similarity index 97% rename from mem0/vector_stores/mariadb.py rename to mem0/vector_stores/aliyun_mysql.py index fad2973f67..4a7003eebc 100644 --- a/mem0/vector_stores/mariadb.py +++ b/mem0/vector_stores/aliyun_mysql.py @@ -23,7 +23,7 @@ class OutputData(BaseModel): payload: Optional[dict] -class MariaDB(VectorStoreBase): +class MySQLVector(VectorStoreBase): def __init__( self, dbname, @@ -44,7 +44,7 @@ def __init__( autocommit=True, ): """ - Initialize the MariaDB Vector database. + Initialize the Aliyun MySQL Vector database. 
Args: dbname (str): Database name @@ -60,7 +60,7 @@ def __init__( ssl_ca (str, optional): SSL CA certificate file path ssl_cert (str, optional): SSL certificate file path ssl_key (str, optional): SSL key file path - connection_string (str, optional): MariaDB connection string (overrides individual connection parameters) + connection_string (str, optional): Aliyun MySQL connection string (overrides individual connection parameters) charset (str): Character set for the connection autocommit (bool): Enable autocommit mode """ @@ -72,7 +72,7 @@ def __init__( # Connection parameters if connection_string: # Parse connection string (simplified parsing) - # Format: mariadb://user:password@host:port/database + # Format: mysql://user:password@host:port/database import urllib.parse parsed = urllib.parse.urlparse(connection_string) self.connection_params = { @@ -133,7 +133,7 @@ def _get_connection(self): def create_col(self) -> None: """ - Create a new collection (table in MariaDB). + Create a new collection (table in Aliyun MySQL). Will also initialize vector search index. """ with self._get_connection() as conn: @@ -203,7 +203,7 @@ def search( filters: Optional[dict] = None, ) -> List[OutputData]: """ - Search for similar vectors using MariaDB Vector distance functions. + Search for similar vectors using Aliyun MySQL Vector distance functions. 
Args: query (str): Query string (for logging) diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index eebf512b52..76ebcac8a4 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -32,7 +32,7 @@ class VectorStoreConfig(BaseModel): "faiss": "FAISSConfig", "langchain": "LangchainConfig", "s3_vectors": "S3VectorsConfig", - "mariadb": "MariaDBConfig", + "aliyun_mysql": "MySQLVectorConfig", } @model_validator(mode="after") diff --git a/tests/vector_stores/test_mariadb.py b/tests/vector_stores/test_aliyun_mysql.py similarity index 62% rename from tests/vector_stores/test_mariadb.py rename to tests/vector_stores/test_aliyun_mysql.py index 9f037ecc0e..e555b9a4c0 100644 --- a/tests/vector_stores/test_mariadb.py +++ b/tests/vector_stores/test_aliyun_mysql.py @@ -14,13 +14,13 @@ 'mysql': mock_mysql, 'mysql.connector': mock_mysql_connector }): - from mem0.vector_stores.mariadb import MariaDB, OutputData + from mem0.vector_stores.aliyun_mysql import MySQLVector, OutputData -class TestMariaDB: +class TestMySQLVector: @pytest.fixture - def mock_mariadb_connection(self): - """Mock MariaDB connection and cursor""" + def mock_mysql_connection(self): + """Mock MySQL connection and cursor""" mock_conn = Mock() mock_cursor = Mock() mock_conn.cursor.return_value = mock_cursor @@ -29,14 +29,14 @@ def mock_mariadb_connection(self): yield mock_conn, mock_cursor @pytest.fixture - def mariadb_store(self, mock_mariadb_connection): - """Create MariaDB vector store instance with mocked connection""" - mock_conn, mock_cursor = mock_mariadb_connection + def mysql_store(self, mock_mysql_connection): + """Create MySQL vector store instance with mocked connection""" + mock_conn, mock_cursor = mock_mysql_connection # Mock list_cols to return empty list initially mock_cursor.fetchall.return_value = [] - store = MariaDB( + store = MySQLVector( dbname="test_db", collection_name="test_collection", embedding_model_dims=3, @@ -47,19 +47,19 @@ def 
mariadb_store(self, mock_mariadb_connection): ) return store - def test_init(self, mariadb_store): - """Test MariaDB initialization""" - assert mariadb_store.collection_name == "test_collection" - assert mariadb_store.embedding_model_dims == 3 - assert mariadb_store.distance_function == "euclidean" - assert mariadb_store.m_value == 16 + def test_init(self, mysql_store): + """Test MySQL initialization""" + assert mysql_store.collection_name == "test_collection" + assert mysql_store.embedding_model_dims == 3 + assert mysql_store.distance_function == "euclidean" + assert mysql_store.m_value == 16 - def test_init_with_connection_string(self, mock_mariadb_connection): + def test_init_with_connection_string(self, mock_mysql_connection): """Test initialization with connection string""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchall.return_value = [] - store = MariaDB( + store = MySQLVector( dbname="test_db", collection_name="test_collection", embedding_model_dims=3, @@ -67,7 +67,7 @@ def test_init_with_connection_string(self, mock_mariadb_connection): password=None, host=None, port=None, - connection_string="mariadb://user:pass@localhost:3306/testdb" + connection_string="mysql://user:pass@localhost:3306/testdb" ) assert store.connection_params['user'] == 'user' @@ -76,33 +76,33 @@ def test_init_with_connection_string(self, mock_mariadb_connection): assert store.connection_params['port'] == 3306 - def test_create_col(self, mariadb_store, mock_mariadb_connection): + def test_create_col(self, mysql_store, mock_mysql_connection): """Test collection creation""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection - mariadb_store.create_col() + mysql_store.create_col() # Verify SQL execution mock_cursor.execute.assert_called() mock_conn.commit.assert_called() - def test_insert(self, mariadb_store, mock_mariadb_connection): + def test_insert(self, mysql_store, 
mock_mysql_connection): """Test vector insertion""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection vectors = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] payloads = [{"key": "value1"}, {"key": "value2"}] ids = ["id1", "id2"] - mariadb_store.insert(vectors, payloads, ids) + mysql_store.insert(vectors, payloads, ids) # Verify execute was called for each vector assert mock_cursor.execute.call_count >= 2 mock_conn.commit.assert_called() - def test_search(self, mariadb_store, mock_mariadb_connection): + def test_search(self, mysql_store, mock_mysql_connection): """Test vector search""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection # Mock search results mock_cursor.fetchall.return_value = [ @@ -111,7 +111,7 @@ def test_search(self, mariadb_store, mock_mariadb_connection): ] query_vector = [1.0, 2.0, 3.0] - results = mariadb_store.search("test query", query_vector, limit=2) + results = mysql_store.search("test query", query_vector, limit=2) assert len(results) == 2 assert isinstance(results[0], OutputData) @@ -119,113 +119,113 @@ def test_search(self, mariadb_store, mock_mariadb_connection): assert results[0].score == 0.1 assert results[0].payload == {"key": "value1"} - def test_search_with_filters(self, mariadb_store, mock_mariadb_connection): + def test_search_with_filters(self, mysql_store, mock_mysql_connection): """Test vector search with filters""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchall.return_value = [("id1", 0.1, '{"key": "value1"}')] query_vector = [1.0, 2.0, 3.0] filters = {"category": "test"} - results = mariadb_store.search("test query", query_vector, limit=1, filters=filters) + results = mysql_store.search("test query", query_vector, limit=1, filters=filters) # Verify filter was applied in SQL mock_cursor.execute.assert_called() call_args = mock_cursor.execute.call_args assert 
"JSON_EXTRACT" in call_args[0][0] - def test_delete(self, mariadb_store, mock_mariadb_connection): + def test_delete(self, mysql_store, mock_mysql_connection): """Test vector deletion""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection - mariadb_store.delete("test_id") + mysql_store.delete("test_id") mock_cursor.execute.assert_called() mock_conn.commit.assert_called() - def test_update(self, mariadb_store, mock_mariadb_connection): + def test_update(self, mysql_store, mock_mysql_connection): """Test vector update""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection new_vector = [7.0, 8.0, 9.0] new_payload = {"updated": "data"} - mariadb_store.update("test_id", vector=new_vector, payload=new_payload) + mysql_store.update("test_id", vector=new_vector, payload=new_payload) # Should be called twice - once for vector, once for payload assert mock_cursor.execute.call_count >= 2 mock_conn.commit.assert_called() - def test_get(self, mariadb_store, mock_mariadb_connection): + def test_get(self, mysql_store, mock_mysql_connection): """Test vector retrieval""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchone.return_value = ("test_id", b"vector_data", '{"key": "value"}') - result = mariadb_store.get("test_id") + result = mysql_store.get("test_id") assert isinstance(result, OutputData) assert result.id == "test_id" assert result.payload == {"key": "value"} - def test_get_not_found(self, mariadb_store, mock_mariadb_connection): + def test_get_not_found(self, mysql_store, mock_mysql_connection): """Test vector retrieval when not found""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchone.return_value = None - result = mariadb_store.get("nonexistent_id") + result = mysql_store.get("nonexistent_id") assert result is None - def 
test_list_cols(self, mariadb_store, mock_mariadb_connection): + def test_list_cols(self, mysql_store, mock_mysql_connection): """Test listing collections""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchall.return_value = [("table1",), ("table2",)] - collections = mariadb_store.list_cols() + collections = mysql_store.list_cols() assert collections == ["table1", "table2"] - def test_delete_col(self, mariadb_store, mock_mariadb_connection): + def test_delete_col(self, mysql_store, mock_mysql_connection): """Test collection deletion""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection - mariadb_store.delete_col() + mysql_store.delete_col() mock_cursor.execute.assert_called() mock_conn.commit.assert_called() - def test_col_info(self, mariadb_store, mock_mariadb_connection): + def test_col_info(self, mysql_store, mock_mysql_connection): """Test collection info retrieval""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchone.side_effect = [ (100,), # row count ("test_collection", 5.5) # table info ] - info = mariadb_store.col_info() + info = mysql_store.col_info() assert info["name"] == "test_collection" assert info["count"] == 100 assert "MB" in info["size"] - def test_list(self, mariadb_store, mock_mariadb_connection): + def test_list(self, mysql_store, mock_mysql_connection): """Test listing vectors""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection mock_cursor.fetchall.return_value = [ ("id1", b"vector1", '{"key": "value1"}'), ("id2", b"vector2", '{"key": "value2"}') ] - results = mariadb_store.list(limit=10) + results = mysql_store.list(limit=10) assert len(results) == 2 assert all(isinstance(r, OutputData) for r in results) - def test_reset(self, mariadb_store, mock_mariadb_connection): + def test_reset(self, mysql_store, 
mock_mysql_connection): """Test collection reset""" - mock_conn, mock_cursor = mock_mariadb_connection + mock_conn, mock_cursor = mock_mysql_connection - mariadb_store.reset() + mysql_store.reset() # Should call delete_col and create_col assert mock_cursor.execute.call_count >= 2 @@ -238,6 +238,6 @@ def test_import_error(self): # Force reimport of the module to trigger the ImportError import importlib import sys - if 'mem0.vector_stores.mariadb' in sys.modules: - del sys.modules['mem0.vector_stores.mariadb'] - importlib.import_module('mem0.vector_stores.mariadb') + if 'mem0.vector_stores.aliyun_mysql' in sys.modules: + del sys.modules['mem0.vector_stores.aliyun_mysql'] + importlib.import_module('mem0.vector_stores.aliyun_mysql') From 33192ac06fdc88812aad0ce7c34ee26cdfef4908 Mon Sep 17 00:00:00 2001 From: ganyang Date: Mon, 22 Sep 2025 13:24:56 +0800 Subject: [PATCH 4/8] feat(vector_stores): Renamed aliyun_mysql to alibabacloud_mysql - Changed the file and class names related to aliyun_mysql to alibabacloud_mysql - Updated related import paths and configuration names - Retained the original functionality and logic --- .../{aliyun_mysql.py => alibabacloud_mysql.py} | 2 +- mem0/utils/factory.py | 2 +- .../{aliyun_mysql.py => alibabacloud_mysql.py} | 8 ++++---- mem0/vector_stores/configs.py | 2 +- .../{test_aliyun_mysql.py => test_alibabacloud_mysql.py} | 8 ++++---- 5 files changed, 11 insertions(+), 11 deletions(-) rename mem0/configs/vector_stores/{aliyun_mysql.py => alibabacloud_mysql.py} (97%) rename mem0/vector_stores/{aliyun_mysql.py => alibabacloud_mysql.py} (97%) rename tests/vector_stores/{test_aliyun_mysql.py => test_alibabacloud_mysql.py} (96%) diff --git a/mem0/configs/vector_stores/aliyun_mysql.py b/mem0/configs/vector_stores/alibabacloud_mysql.py similarity index 97% rename from mem0/configs/vector_stores/aliyun_mysql.py rename to mem0/configs/vector_stores/alibabacloud_mysql.py index d874b14f9d..7db331f02c 100644 --- 
a/mem0/configs/vector_stores/aliyun_mysql.py +++ b/mem0/configs/vector_stores/alibabacloud_mysql.py @@ -18,7 +18,7 @@ class MySQLVectorConfig(BaseModel): ssl_ca: Optional[str] = Field(None, description="SSL CA certificate file path") ssl_cert: Optional[str] = Field(None, description="SSL certificate file path") ssl_key: Optional[str] = Field(None, description="SSL key file path") - connection_string: Optional[str] = Field(None, description="Aliyun MySQL connection string (overrides individual connection parameters)") + connection_string: Optional[str] = Field(None, description="AlibabaCloud MySQL connection string (overrides individual connection parameters)") charset: Optional[str] = Field("utf8mb4", description="Character set for the connection") autocommit: Optional[bool] = Field(True, description="Enable autocommit mode") diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index c1313586d8..7dfe87d3b4 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -177,7 +177,7 @@ class VectorStoreFactory: "s3_vectors": "mem0.vector_stores.s3_vectors.S3Vectors", "baidu": "mem0.vector_stores.baidu.BaiduDB", "neptune": "mem0.vector_stores.neptune_analytics.NeptuneAnalyticsVector", - "aliyun_mysql": "mem0.vector_stores.aliyun_mysql.MySQLVector", + "alibabacloud_mysql": "mem0.vector_stores.alibabacloud_mysql.MySQLVector", } @classmethod diff --git a/mem0/vector_stores/aliyun_mysql.py b/mem0/vector_stores/alibabacloud_mysql.py similarity index 97% rename from mem0/vector_stores/aliyun_mysql.py rename to mem0/vector_stores/alibabacloud_mysql.py index 4a7003eebc..cb5becddd0 100644 --- a/mem0/vector_stores/aliyun_mysql.py +++ b/mem0/vector_stores/alibabacloud_mysql.py @@ -44,7 +44,7 @@ def __init__( autocommit=True, ): """ - Initialize the Aliyun MySQL Vector database. + Initialize the AlibabaCloud MySQL Vector database. 
Args: dbname (str): Database name @@ -60,7 +60,7 @@ def __init__( ssl_ca (str, optional): SSL CA certificate file path ssl_cert (str, optional): SSL certificate file path ssl_key (str, optional): SSL key file path - connection_string (str, optional): Aliyun MySQL connection string (overrides individual connection parameters) + connection_string (str, optional): AlibabaCloud MySQL connection string (overrides individual connection parameters) charset (str): Character set for the connection autocommit (bool): Enable autocommit mode """ @@ -133,7 +133,7 @@ def _get_connection(self): def create_col(self) -> None: """ - Create a new collection (table in Aliyun MySQL). + Create a new collection (table in AlibabaCloud MySQL). Will also initialize vector search index. """ with self._get_connection() as conn: @@ -203,7 +203,7 @@ def search( filters: Optional[dict] = None, ) -> List[OutputData]: """ - Search for similar vectors using Aliyun MySQL Vector distance functions. + Search for similar vectors using AlibabaCloud MySQL Vector distance functions. 
Args: query (str): Query string (for logging) diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 76ebcac8a4..f605e28a0d 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -32,7 +32,7 @@ class VectorStoreConfig(BaseModel): "faiss": "FAISSConfig", "langchain": "LangchainConfig", "s3_vectors": "S3VectorsConfig", - "aliyun_mysql": "MySQLVectorConfig", + "alibabacloud_mysql": "MySQLVectorConfig", } @model_validator(mode="after") diff --git a/tests/vector_stores/test_aliyun_mysql.py b/tests/vector_stores/test_alibabacloud_mysql.py similarity index 96% rename from tests/vector_stores/test_aliyun_mysql.py rename to tests/vector_stores/test_alibabacloud_mysql.py index e555b9a4c0..9122c0d805 100644 --- a/tests/vector_stores/test_aliyun_mysql.py +++ b/tests/vector_stores/test_alibabacloud_mysql.py @@ -14,7 +14,7 @@ 'mysql': mock_mysql, 'mysql.connector': mock_mysql_connector }): - from mem0.vector_stores.aliyun_mysql import MySQLVector, OutputData + from mem0.vector_stores.alibabacloud_mysql import MySQLVector, OutputData class TestMySQLVector: @@ -238,6 +238,6 @@ def test_import_error(self): # Force reimport of the module to trigger the ImportError import importlib import sys - if 'mem0.vector_stores.aliyun_mysql' in sys.modules: - del sys.modules['mem0.vector_stores.aliyun_mysql'] - importlib.import_module('mem0.vector_stores.aliyun_mysql') + if 'mem0.vector_stores.alibabacloud_mysql' in sys.modules: + del sys.modules['mem0.vector_stores.alibabacloud_mysql'] + importlib.import_module('mem0.vector_stores.alibabacloud_mysql') From a50e3f9362844d39ab6603236b1a03b82e14d146 Mon Sep 17 00:00:00 2001 From: ganyang Date: Sun, 19 Oct 2025 11:40:05 +0800 Subject: [PATCH 5/8] test(test_alibabacloud_mysql): Modified the search test to verify filter application. 
- Removed unused imports of struct - Adjusted the search method call to remove the variable that receives the result - Retained the assertion check in mock_cursor.execute to verify that the filter is applied in the SQL query --- tests/vector_stores/test_alibabacloud_mysql.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/vector_stores/test_alibabacloud_mysql.py b/tests/vector_stores/test_alibabacloud_mysql.py index 9122c0d805..8813664be0 100644 --- a/tests/vector_stores/test_alibabacloud_mysql.py +++ b/tests/vector_stores/test_alibabacloud_mysql.py @@ -1,5 +1,4 @@ import pytest -import struct from unittest.mock import Mock, patch # Create mock modules for mysql and mysql.connector @@ -127,7 +126,7 @@ def test_search_with_filters(self, mysql_store, mock_mysql_connection): query_vector = [1.0, 2.0, 3.0] filters = {"category": "test"} - results = mysql_store.search("test query", query_vector, limit=1, filters=filters) + mysql_store.search("test query", query_vector, limit=1, filters=filters) # Verify filter was applied in SQL mock_cursor.execute.assert_called() From c8190494ad935b28dfa998d7c3b3020d4af2c2f9 Mon Sep 17 00:00:00 2001 From: ganyang Date: Sun, 19 Oct 2025 11:42:44 +0800 Subject: [PATCH 6/8] test(test_alibabacloud_mysql): Modified the search test to verify filter application. 
- Removed unused imports of struct - Adjusted the search method call to remove the variable that receives the result - Retained the assertion check in mock_cursor.execute to verify that the filter is applied in the SQL query --- mem0/vector_stores/alibabacloud_mysql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mem0/vector_stores/alibabacloud_mysql.py b/mem0/vector_stores/alibabacloud_mysql.py index cb5becddd0..56efbafbab 100644 --- a/mem0/vector_stores/alibabacloud_mysql.py +++ b/mem0/vector_stores/alibabacloud_mysql.py @@ -392,7 +392,7 @@ def col_info(self) -> dict[str, Any]: row_count = cursor.fetchone()[0] # Get table size information - cursor.execute(f""" + cursor.execute(""" SELECT table_name, ROUND(((data_length + index_length) / 1024 / 1024), 2) AS total_size_mb From 436c09b7a2c6b8cf957a01c8e6f1f26dd944c271 Mon Sep 17 00:00:00 2001 From: ganyang Date: Sun, 19 Oct 2025 13:03:44 +0800 Subject: [PATCH 7/8] docs (components): Adds support for the AlibabaCloud MySQL vector database - Adds a guide to installing and using the AlibabaCloud MySQL database - Provides sample code for various connection methods, including standard connections, SSL connections, and connection string methods - Provides a detailed list of configuration parameters and their default values - Introduces database features, such as native vector support, HNSW indexes, and various distance functions - Explains prerequisites and connection string formats - Explains distance functions and performance tuning parameters - Describes error handling mechanisms --- .../vectordbs/dbs/alibabacloud_mysql.mdx | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 docs/components/vectordbs/dbs/alibabacloud_mysql.mdx diff --git a/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx b/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx new file mode 100644 index 0000000000..fe7c6ff71c --- /dev/null +++ b/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx @@ 
-0,0 +1,162 @@ +[AlibabaCloud MySQL](https://www.alibabacloud.com/product/apsaradb-for-rds-mysql) is a fully managed MySQL database service that supports vector operations through MySQL's native VECTOR data type and vector functions. It provides high-performance vector similarity search capabilities with HNSW indexing. + +### Installation +```bash +pip install mysql-connector-python +``` + +### Usage + + +```python Python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "alibabacloud_mysql", + "config": { + "dbname": "vector_db", + "collection_name": "memories", + "user": "your_username", + "password": "your_password", + "host": "your-mysql-host.mysql.rds.aliyuncs.com", + "port": 3306, + "embedding_model_dims": 1536, + "distance_function": "euclidean", + "m_value": 16 + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! 
I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + +```python Python (with connection string) +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "alibabacloud_mysql", + "config": { + "connection_string": "mysql://username:password@host:3306/database", + "collection_name": "memories", + "distance_function": "cosine" + } + } +} + +m = Memory.from_config(config) +``` + +```python Python (with SSL) +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "alibabacloud_mysql", + "config": { + "dbname": "vector_db", + "collection_name": "memories", + "user": "your_username", + "password": "your_password", + "host": "your-mysql-host.mysql.rds.aliyuncs.com", + "port": 3306, + "embedding_model_dims": 1536, + "ssl_ca": "/path/to/ca-cert.pem", + "ssl_cert": "/path/to/client-cert.pem", + "ssl_key": "/path/to/client-key.pem" + } + } +} + +m = Memory.from_config(config) +``` + + +### Config + +Here are the parameters available for configuring AlibabaCloud MySQL: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `dbname` | The name of the database | `mem0` | +| `collection_name` | The name of the collection (table) | `mem0` | +| `embedding_model_dims` | Dimensions of the embedding model | `1536` | +| `user` | Database username | `None` | +| `password` | Database password | `None` | +| `host` | Database host address | `None` | +| `port` | Database port | `3306` | +| `distance_function` | Distance function for vector index ('euclidean' or 'cosine') | `euclidean` | +| `m_value` | M parameter for HNSW index (3-200).
Higher values = more accurate but slower | `16` | +| `ssl_disabled` | Disable SSL connection | `False` | +| `ssl_ca` | SSL CA certificate file path | `None` | +| `ssl_cert` | SSL certificate file path | `None` | +| `ssl_key` | SSL key file path | `None` | +| `connection_string` | MySQL connection string (overrides individual connection parameters) | `None` | +| `charset` | Character set for the connection | `utf8mb4` | +| `autocommit` | Enable autocommit mode | `True` | + +### Features + +- **Native Vector Support**: Uses MySQL's native VECTOR data type for optimal performance +- **HNSW Indexing**: Supports Hierarchical Navigable Small World (HNSW) indexing for fast similarity search +- **Multiple Distance Functions**: Supports both Euclidean and Cosine distance functions +- **SSL Support**: Full SSL/TLS encryption support for secure connections +- **JSON Payload**: Store additional metadata as JSON alongside vectors +- **Connection Pooling**: Efficient connection management with context managers +- **Flexible Configuration**: Support for connection strings or individual parameters + +### Prerequisites + +1. **AlibabaCloud MySQL Instance**: You need an AlibabaCloud RDS MySQL instance with vector support enabled +2. **Database Setup**: Ensure your MySQL instance supports the VECTOR data type and vector functions +3. **Network Access**: Configure security groups and whitelist to allow connections from your application + +### Connection String Format + +The connection string should follow this format: +``` +mysql://username:password@host:port/database +``` + +Example: +``` +mysql://myuser:mypass@rm-xxxxxxxx.mysql.rds.aliyuncs.com:3306/vector_db +``` + +### Distance Functions + +- **euclidean**: Uses Euclidean distance for vector similarity +- **cosine**: Uses Cosine distance for vector similarity + +The distance function is set during collection creation and affects the HNSW index configuration. 
+ +### Performance Tuning + +- **m_value**: Controls the HNSW index quality vs speed tradeoff + - Lower values (3-16): Faster indexing, less memory, lower accuracy + - Higher values (32-200): Slower indexing, more memory, higher accuracy +- **embedding_model_dims**: Should match your embedding model's output dimensions exactly + +### Error Handling + +The connector includes comprehensive error handling and logging: +- Connection errors are automatically caught and logged +- Failed transactions are rolled back automatically +- Detailed error messages help with debugging From 5c0bf4d1a167c8a1d5b10b9ea3908095829cf621 Mon Sep 17 00:00:00 2001 From: ganyang Date: Sun, 19 Oct 2025 13:10:53 +0800 Subject: [PATCH 8/8] docs (components): Adds support for the AlibabaCloud MySQL vector database - Adds a guide to installing and using the AlibabaCloud MySQL database - Provides sample code for various connection methods, including standard connections, SSL connections, and connection string methods - Provides a detailed list of configuration parameters and their default values - Introduces database features, such as native vector support, HNSW indexes, and various distance functions - Explains prerequisites and connection string formats - Explains distance functions and performance tuning parameters - Describes error handling mechanisms --- docs/components/vectordbs/dbs/alibabacloud_mysql.mdx | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx b/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx index fe7c6ff71c..76d4c7d068 100644 --- a/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx +++ b/docs/components/vectordbs/dbs/alibabacloud_mysql.mdx @@ -1,10 +1,5 @@ [AlibabaCloud MySQL](https://www.alibabacloud.com/product/apsaradb-for-rds-mysql) is a fully managed MySQL database service that supports vector operations through MySQL's native VECTOR data type and vector functions. 
It provides high-performance vector similarity search capabilities with HNSW indexing. -### Installation -```bash -pip install mysql-connector-python -``` - ### Usage