From 5e1b62fc2f87928c0f607e2156b9a2f9744a089b Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Wed, 4 Dec 2024 16:48:39 -0500 Subject: [PATCH 01/31] chore: update notebooks with small typos (#25) * chore: llama_index_doc_store.ipynb * chore: update llama_index_vector_store.ipynb * chore: Update llama_index_vector_store.ipynb --- samples/llama_index_doc_store.ipynb | 4 ++-- samples/llama_index_vector_store.ipynb | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb index b7b02e5..8c9e78b 100644 --- a/samples/llama_index_doc_store.ipynb +++ b/samples/llama_index_doc_store.ipynb @@ -8,7 +8,7 @@ "source": [ "# Google Cloud SQL for PostgreSQL - `PostgresDocumentStore` & `PostgresIndexStore`\n", "\n", - "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n", + "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n", "\n", "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store documents and indexes with the `PostgresDocumentStore` and `PostgresIndexStore` classes.\n", "\n", @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-cloud-sql-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb index a8efe40..fd8cc3e 100644 --- a/samples/llama_index_vector_store.ipynb +++ b/samples/llama_index_vector_store.ipynb @@ -8,13 +8,13 @@ "source": [ "# Google Cloud SQL for PostgreSQL - `PostgresVectorStore`\n", "\n", - "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's Langchain integrations.\n", + "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. 
Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n", "\n", "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store vector embeddings with the `PostgresVectorStore` class.\n", "\n", "Learn more about the package on [GitHub](https://github.com/googleapis/llama-index-cloud-sql-pg-python/).\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-cloud-sql-pg-python/blob/main/docs/vector_store.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googleapis/llama-index-cloud-sql-pg-python/blob/main/samples/llama_index_vector_store.ipynb)" ] }, { @@ -40,7 +40,7 @@ "id": "IR54BmgvdHT_" }, "source": [ - "### 🦜🔗 Library Installation\n", + "### 🦙 Library Installation\n", "Install the integration library, `llama-index-cloud-sql-pg`, and the library for the embedding service, `llama-index-embeddings-vertex`." ] }, From 802c44324685663c109bde031a2dff3fb4767d34 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Thu, 5 Dec 2024 13:16:17 -0500 Subject: [PATCH 02/31] chore: adhere to PEP 585 and removed unused imports (#26) * chore: adhere to PEP 585 and removed unused imports * chore: lint with isort --- .../async_document_store.py | 40 +++++++++---------- .../async_index_store.py | 9 ++--- .../async_vector_store.py | 37 ++++++++--------- .../document_store.py | 34 ++++++++-------- src/llama_index_cloud_sql_pg/engine.py | 26 ++++-------- src/llama_index_cloud_sql_pg/index_store.py | 10 ++--- src/llama_index_cloud_sql_pg/indexes.py | 4 +- src/llama_index_cloud_sql_pg/vector_store.py | 30 +++++++------- tests/test_async_vector_store.py | 2 +- tests/test_async_vector_store_index.py | 4 +- tests/test_vector_store.py | 2 +- tests/test_vector_store_index.py | 6 +-- 12 files changed, 91 insertions(+), 113 deletions(-) diff --git a/src/llama_index_cloud_sql_pg/async_document_store.py b/src/llama_index_cloud_sql_pg/async_document_store.py index 9a1060b..d20c8ef 100644 --- a/src/llama_index_cloud_sql_pg/async_document_store.py +++ b/src/llama_index_cloud_sql_pg/async_document_store.py @@ -16,7 +16,7 @@ import json import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Optional, Sequence from llama_index.core.constants import DATA_KEY from llama_index.core.schema import BaseNode @@ -119,13 +119,13 @@ async def __afetch_query(self, query): return results async def _put_all_doc_hashes_to_table( - self, rows: List[Tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) + self, rows: list[tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE) ) -> None: """Puts a multiple rows of node ids with their doc_hash into the document table. Incase a row with the id already exists, it updates the row with the new doc_hash. Args: - rows (List[Tuple[str, str]]): List of tuples of id and doc_hash + rows (list[tuple[str, str]]): List of tuples of id and doc_hash batch_size (int): batch_size to insert the rows. Defaults to 1. Returns: @@ -173,7 +173,7 @@ async def async_add_documents( """Adds a document to the store. Args: - docs (List[BaseDocument]): documents + docs (list[BaseDocument]): documents allow_update (bool): allow update of docstore from document batch_size (int): batch_size to insert the rows. Defaults to 1. store_text (bool): allow the text content of the node to stored. 
@@ -225,11 +225,11 @@ async def async_add_documents( await self.__aexecute_query(query, batch) @property - async def adocs(self) -> Dict[str, BaseNode]: + async def adocs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" list_docs = await self.__afetch_query(query) @@ -300,12 +300,12 @@ async def aget_ref_doc_info(self, ref_doc_id: str) -> Optional[RefDocInfo]: return RefDocInfo(node_ids=node_ids, metadata=merged_metadata) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -356,14 +356,14 @@ async def adocument_exists(self, doc_id: str) -> bool: async def _get_ref_doc_child_node_ids( self, ref_doc_id: str - ) -> Optional[Dict[str, List[str]]]: + ) -> Optional[dict[str, list[str]]]: """Helper function to find the child node mappings of a ref_doc_id. Returns: Optional[ - Dict[ + dict[ str, # Ref_doc_id - List # List of all nodes that refer to ref_doc_id + list # List of all nodes that refer to ref_doc_id ] ]""" query = f"""select id from "{self._schema_name}"."{self._table_name}" where ref_doc_id = '{ref_doc_id}';""" @@ -442,11 +442,11 @@ async def aset_document_hash(self, doc_id: str, doc_hash: str) -> None: await self._put_all_doc_hashes_to_table(rows=[(doc_id, doc_hash)]) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -473,11 +473,11 @@ async def aget_document_hash(self, doc_id: str) -> Optional[str]: else: return None - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -498,11 +498,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: return hashes @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ raise NotImplementedError( @@ -547,7 +547,7 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) @@ -557,12 +557,12 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. 
Use PostgresDocumentStore interface instead." ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." ) diff --git a/src/llama_index_cloud_sql_pg/async_index_store.py b/src/llama_index_cloud_sql_pg/async_index_store.py index fde312e..ee06c53 100644 --- a/src/llama_index_cloud_sql_pg/async_index_store.py +++ b/src/llama_index_cloud_sql_pg/async_index_store.py @@ -16,9 +16,8 @@ import json import warnings -from typing import List, Optional +from typing import Optional -from llama_index.core.constants import DATA_KEY from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore from llama_index.core.storage.index_store.utils import ( @@ -113,11 +112,11 @@ async def __afetch_query(self, query): await conn.commit() return results - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";""" @@ -190,7 +189,7 @@ async def aget_index_struct( return json_to_index_struct(index_data) return None - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead." ) diff --git a/src/llama_index_cloud_sql_pg/async_vector_store.py b/src/llama_index_cloud_sql_pg/async_vector_store.py index 82b9857..20baead 100644 --- a/src/llama_index_cloud_sql_pg/async_vector_store.py +++ b/src/llama_index_cloud_sql_pg/async_vector_store.py @@ -15,14 +15,10 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -import base64 import json -import re -import uuid import warnings -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Sequence -import numpy as np from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -31,7 +27,6 @@ MetadataFilter, MetadataFilters, VectorStoreQuery, - VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.core.vector_stores.utils import ( @@ -70,7 +65,7 @@ def __init__( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -88,7 +83,7 @@ def __init__( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". 
node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -120,7 +115,7 @@ def __init__( @classmethod async def create( - cls: Type[AsyncPostgresVectorStore], + cls: type[AsyncPostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -128,7 +123,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -146,7 +141,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -233,7 +228,7 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" ids = [] metadata_col_names = ( @@ -292,14 +287,14 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" if not node_ids and not filters: return - all_filters: List[MetadataFilter | MetadataFilters] = [] + all_filters: list[MetadataFilter | MetadataFilters] = [] if node_ids: all_filters.append( MetadataFilter( @@ -331,9 +326,9 @@ async def aclear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" query = VectorStoreQuery( node_ids=node_ids, filters=filters, similarity_top_k=-1 @@ -365,7 +360,7 @@ async def aquery( similarities.append(row["distance"]) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." 
) @@ -377,7 +372,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -392,9 +387,9 @@ def clear(self) -> None: def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: raise NotImplementedError( "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." ) @@ -486,7 +481,7 @@ async def __query_columns( **kwargs: Any, ) -> Sequence[RowMapping]: """Perform search query on database.""" - filters: List[MetadataFilter | MetadataFilters] = [] + filters: list[MetadataFilter | MetadataFilters] = [] if query.doc_ids: filters.append( MetadataFilter( diff --git a/src/llama_index_cloud_sql_pg/document_store.py b/src/llama_index_cloud_sql_pg/document_store.py index 020128e..f4ff3db 100644 --- a/src/llama_index_cloud_sql_pg/document_store.py +++ b/src/llama_index_cloud_sql_pg/document_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Dict, List, Optional, Sequence, Type +from typing import Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.storage.docstore import BaseDocumentStore @@ -55,7 +55,7 @@ def __init__( @classmethod async def create( - cls: Type[PostgresDocumentStore], + cls: type[PostgresDocumentStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -83,7 +83,7 @@ async def create( @classmethod def create_sync( - cls: Type[PostgresDocumentStore], + cls: type[PostgresDocumentStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -110,11 +110,11 @@ def create_sync( return cls(cls.__create_key, engine, document_store) @property - def docs(self) -> Dict[str, BaseNode]: + def docs(self) -> dict[str, BaseNode]: """Get all documents. Returns: - Dict[str, BaseDocument]: documents + dict[str, BaseDocument]: documents """ return self._engine._run_as_sync(self.__document_store.adocs) @@ -291,11 +291,11 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None: self.__document_store.aset_document_hash(doc_id, doc_hash) ) - async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -304,11 +304,11 @@ async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None: self.__document_store.aset_document_hashes(doc_hashes) ) - def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None: + def set_document_hashes(self, doc_hashes: dict[str, str]) -> None: """Set the hash for a given doc_id. Args: - doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value. + doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value. Returns: None @@ -343,11 +343,11 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: self.__document_store.aget_document_hash(doc_id) ) - async def aget_all_document_hashes(self) -> Dict[str, str]: + async def aget_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. 
Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -356,11 +356,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - def get_all_document_hashes(self) -> Dict[str, str]: + def get_all_document_hashes(self) -> dict[str, str]: """Get the stored hash for all documents. Returns: - Dict[ + dict[ str, # doc_hash str # doc_id ] @@ -369,12 +369,12 @@ def get_all_document_hashes(self) -> Dict[str, str]: self.__document_store.aget_all_document_hashes() ) - async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] @@ -384,12 +384,12 @@ async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: self.__document_store.aget_all_ref_doc_info() ) - def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]: + def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]: """Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents. Returns: Optional[ - Dict[ + dict[ str, #Ref_doc_id RefDocInfo, #Ref_doc_info of the id ] diff --git a/src/llama_index_cloud_sql_pg/engine.py b/src/llama_index_cloud_sql_pg/engine.py index b0a3ed5..cc067db 100644 --- a/src/llama_index_cloud_sql_pg/engine.py +++ b/src/llama_index_cloud_sql_pg/engine.py @@ -17,17 +17,7 @@ from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Dict, - List, - Optional, - Type, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union import aiohttp import google.auth # type: ignore @@ -75,7 +65,7 @@ async def _get_iam_principal_email( url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" async with aiohttp.ClientSession() as client: response = await client.get(url, raise_for_status=True) - response_json: Dict = await response.json() + response_json: dict = await response.json() email = response_json.get("email") if email is None: raise ValueError( @@ -511,7 +501,7 @@ async def _ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -527,7 +517,7 @@ async def _ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". 
@@ -584,7 +574,7 @@ async def ainit_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -600,7 +590,7 @@ async def ainit_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -635,7 +625,7 @@ def init_vector_store_table( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[Column] = [], + metadata_columns: list[Column] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -651,7 +641,7 @@ def init_vector_store_table( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". diff --git a/src/llama_index_cloud_sql_pg/index_store.py b/src/llama_index_cloud_sql_pg/index_store.py index cb41b49..3103d54 100644 --- a/src/llama_index_cloud_sql_pg/index_store.py +++ b/src/llama_index_cloud_sql_pg/index_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from llama_index.core.data_structs.data_structs import IndexStruct from llama_index.core.storage.index_store.types import BaseIndexStore @@ -96,20 +96,20 @@ def create_sync( index_store = engine._run_as_sync(coro) return cls(cls.__create_key, engine, index_store) - async def aindex_structs(self) -> List[IndexStruct]: + async def aindex_structs(self) -> list[IndexStruct]: """Get all index structs. Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return await self._engine._run_as_async(self.__index_store.aindex_structs()) - def index_structs(self) -> List[IndexStruct]: + def index_structs(self) -> list[IndexStruct]: """Get all index structs. 
Returns: - List[IndexStruct]: index structs + list[IndexStruct]: index structs """ return self._engine._run_as_sync(self.__index_store.aindex_structs()) diff --git a/src/llama_index_cloud_sql_pg/indexes.py b/src/llama_index_cloud_sql_pg/indexes.py index 1367d51..9e9de00 100644 --- a/src/llama_index_cloud_sql_pg/indexes.py +++ b/src/llama_index_cloud_sql_pg/indexes.py @@ -15,7 +15,7 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional @dataclass @@ -44,7 +44,7 @@ class BaseIndex(ABC): distance_strategy: DistanceStrategy = field( default_factory=lambda: DistanceStrategy.COSINE_DISTANCE ) - partial_indexes: Optional[List[str]] = None + partial_indexes: Optional[list[str]] = None @abstractmethod def index_options(self) -> str: diff --git a/src/llama_index_cloud_sql_pg/vector_store.py b/src/llama_index_cloud_sql_pg/vector_store.py index 7e3cf4d..2e9f244 100644 --- a/src/llama_index_cloud_sql_pg/vector_store.py +++ b/src/llama_index_cloud_sql_pg/vector_store.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Sequence, Type +from typing import Any, Optional, Sequence from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import ( @@ -71,7 +71,7 @@ def __init__( @classmethod async def create( - cls: Type[PostgresVectorStore], + cls: type[PostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -79,7 +79,7 @@ async def create( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -97,7 +97,7 @@ async def create( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -138,7 +138,7 @@ async def create( @classmethod def create_sync( - cls: Type[PostgresVectorStore], + cls: type[PostgresVectorStore], engine: PostgresEngine, table_name: str, schema_name: str = "public", @@ -146,7 +146,7 @@ def create_sync( text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: str = "li_metadata", - metadata_columns: List[str] = [], + metadata_columns: list[str] = [], ref_doc_id_column: str = "ref_doc_id", node_column: str = "node_data", stores_text: bool = True, @@ -164,7 +164,7 @@ def create_sync( text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". 
- metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node_data". stores_text (bool): Whether the table stores text. Defaults to "True". @@ -212,11 +212,11 @@ def client(self) -> Any: """Get client.""" return self._engine - async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]: """Asynchronously add nodes to the table.""" return await self._engine._run_as_async(self.__vs.async_add(nodes, **kwargs)) - def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]: """Synchronously add nodes to the table.""" return self._engine._run_as_sync(self.__vs.async_add(nodes, **add_kwargs)) @@ -230,7 +230,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: async def adelete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -241,7 +241,7 @@ async def adelete_nodes( def delete_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: @@ -260,17 +260,17 @@ def clear(self) -> None: async def aget_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return await self._engine._run_as_async(self.__vs.aget_nodes(node_ids, filters)) def get_nodes( self, - node_ids: Optional[List[str]] = None, + node_ids: Optional[list[str]] = None, filters: Optional[MetadataFilters] = None, - ) -> List[BaseNode]: + ) -> list[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" return self._engine._run_as_sync(self.__vs.aget_nodes(node_ids, filters)) diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 45b8b6a..c38ceee 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,7 +14,7 @@ import os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index af86315..1f392fc 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -15,13 +15,11 @@ import os import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping from llama_index_cloud_sql_pg import PostgresEngine from llama_index_cloud_sql_pg.async_vector_store import AsyncPostgresVectorStore diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index b6939b0..c365730 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -14,7 +14,7 @@ import 
os import uuid -from typing import List, Sequence +from typing import Sequence import pytest import pytest_asyncio diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index 87eff8a..ba63004 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -14,16 +14,12 @@ import os -import sys import uuid -from typing import List, Sequence import pytest import pytest_asyncio -from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.schema import TextNode from sqlalchemy import text -from sqlalchemy.engine.row import RowMapping -from sqlalchemy.ext.asyncio import create_async_engine from llama_index_cloud_sql_pg import PostgresEngine, PostgresVectorStore from llama_index_cloud_sql_pg.indexes import ( # type: ignore From 47fef5fec26260a2144450f8d76f68956ebc01d2 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Thu, 5 Dec 2024 19:27:22 +0100 Subject: [PATCH 03/31] fix(deps): update python-nonmajor (#22) Co-authored-by: Vishwaraj Anand Co-authored-by: Averi Kitsch --- pyproject.toml | 6 +++--- requirements.txt | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5c56b7f..c7615d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,11 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/ [project.optional-dependencies] test = [ - "black[jupyter]==24.8.0", + "black[jupyter]==24.10.0", "isort==5.13.2", - "mypy==1.11.2", + "mypy==1.13.0", "pytest-asyncio==0.24.0", - "pytest==8.3.3", + "pytest==8.3.4", "pytest-cov==6.0.0" ] diff --git a/requirements.txt b/requirements.txt index efd3592..a366819 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.12.1 -llama-index-core==0.12.0 +cloud-sql-python-connector[asyncpg]==1.14.0 +llama-index-core==0.12.2 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From ee760e920f5cc3e02709b1bc7b5c00865c2c9aca Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 6 Dec 2024 13:22:47 +0100 Subject: [PATCH 04/31] chore(deps): update dependency llama-index-core to v0.12.3 (#30) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a366819..6f643b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.14.0 -llama-index-core==0.12.2 +llama-index-core==0.12.3 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From 34516568ddd6f9e63f91754879fd86ab5ecedcd3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Thu, 12 Dec 2024 19:04:36 +0100 Subject: [PATCH 05/31] chore(deps): update python-nonmajor (#31) --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6f643b0..196a1d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.14.0 -llama-index-core==0.12.3 +cloud-sql-python-connector[asyncpg]==1.15.0 +llama-index-core==0.12.5 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From ff4503bf842135c63263d60318d36939c7ef09e9 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Thu, 19 Dec 2024 00:57:56 +0530 Subject: [PATCH 06/31] chore(ci): add Cloud Build failure reporter (#34) * chore(ci): add Cloud Build failure reporter * chore: refer to langchain alloy db workflow --- .github/workflows/schedule_reporter.yml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 
.github/workflows/schedule_reporter.yml diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml new file mode 100644 index 0000000..ab846ef --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,25 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + uses: googleapis/langchain-google-alloydb-pg-python/.github/workflows/cloud_build_failure_reporter.yml@main + with: + trigger_names: "integration-test-nightly,continuous-test-on-merge" From 5085182088f9d3311017abfccae64959791499a3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 3 Jan 2025 11:07:54 +0100 Subject: [PATCH 07/31] chore(deps): update python-nonmajor (#33) --- pyproject.toml | 2 +- requirements.txt | 2 +- tests/test_async_document_store.py | 4 ++-- tests/test_document_store.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c7615d4..58adf9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ test = [ "black[jupyter]==24.10.0", "isort==5.13.2", "mypy==1.13.0", - "pytest-asyncio==0.24.0", + "pytest-asyncio==0.25.0", "pytest==8.3.4", "pytest-cov==6.0.0" ] diff --git a/requirements.txt b/requirements.txt index 196a1d7..8fe0878 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.15.0 -llama-index-core==0.12.5 +llama-index-core==0.12.6 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index d47d70a..d04db53 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -159,7 +159,7 @@ async def test_async_add_document(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, async_engine, doc_store): # Create a document @@ -176,7 +176,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. 
diff --git a/tests/test_document_store.py b/tests/test_document_store.py index c8d86df..61d6786 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -133,7 +133,7 @@ async def test_async_add_document(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, async_engine, doc_store): # Create a document @@ -150,7 +150,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): query = f"""select * from "public"."{default_table_name_async}" where id = '{doc.doc_id}';""" results = await afetch(async_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. @@ -431,7 +431,7 @@ async def test_add_document(self, sync_engine, sync_doc_store): query = f"""select * from "public"."{default_table_name_sync}" where id = '{doc.doc_id}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_add_hash_before_data(self, sync_engine, sync_doc_store): # Create a document @@ -448,7 +448,7 @@ async def test_add_hash_before_data(self, sync_engine, sync_doc_store): query = f"""select * from "public"."{default_table_name_sync}" where id = '{doc.doc_id}';""" results = await afetch(sync_engine, query) result = results[0] - assert result["node_data"][DATA_KEY]["text"] == document_text + assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text async def test_ref_doc_exists(self, sync_doc_store): # Create a ref_doc & a doc and add them to the store. 
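[Editor's note] The test changes in this patch track a node-serialization change in recent `llama-index-core` 0.12.x releases: a node's text is no longer a top-level `text` field in the serialized `node_data`, but is nested under a `text_resource` object, which is exactly what the updated assertions read. A minimal sketch of the difference, assuming the store serializes nodes with the standard `doc_to_json` utility (the store's actual serialization call is not shown in this excerpt):

```python
from llama_index.core.constants import DATA_KEY
from llama_index.core.schema import TextNode
from llama_index.core.storage.docstore.utils import doc_to_json

# Serialize a node the way a document store does before persisting node_data.
node = TextNode(text="hello world")
node_data = doc_to_json(node)

# llama-index-core < 0.12 kept the content at node_data[DATA_KEY]["text"];
# 0.12.x nests it under text_resource, which is what the tests now assert:
assert node_data[DATA_KEY]["text_resource"]["text"] == "hello world"
```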
From 12cccb83a376cc93c23f979e2d16335a0757dc09 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 7 Jan 2025 01:15:46 +0530 Subject: [PATCH 08/31] chore(test): index tests throwing DuplicateTableError due to undeleted index (#37) --- tests/test_vector_store_index.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index ba63004..a0ac37c 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -32,7 +32,6 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") -CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX VECTOR_SIZE = 5 @@ -112,14 +111,15 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - engine.init_vector_store_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) vs = PostgresVectorStore.create_sync( engine, table_name=DEFAULT_TABLE, ) - await vs.async_add(nodes) - + vs.add(nodes) vs.drop_vector_index() yield vs @@ -127,6 +127,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() vs.apply_vector_index(index) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -135,6 +136,7 @@ async def test_areindex(self, vs): vs.reindex() vs.reindex(DEFAULT_INDEX_NAME) assert vs.is_valid_index(DEFAULT_INDEX_NAME) + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_dropindex(self, vs): vs.drop_vector_index() @@ -152,6 +154,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): vs.apply_vector_index(index) assert vs.is_valid_index("secondindex") vs.drop_vector_index("secondindex") + vs.drop_vector_index(DEFAULT_INDEX_NAME) async def test_is_valid_index(self, vs): is_valid = vs.is_valid_index("invalid_index") @@ -198,7 +201,9 @@ async def engine(self, db_project, db_region, db_instance, db_name): @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - await engine.ainit_vector_store_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) + await engine.ainit_vector_store_table( + DEFAULT_TABLE_ASYNC, VECTOR_SIZE, overwrite_existing=True + ) vs = await PostgresVectorStore.create( engine, table_name=DEFAULT_TABLE_ASYNC, @@ -212,6 +217,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_areindex(self, vs): if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): @@ -220,6 +226,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -242,16 +249,3 @@ async def test_aapply_vector_index_ivfflat(self, vs): async def test_is_valid_index(self, vs): is_valid = await vs.ais_valid_index("invalid_index") assert is_valid == False - - async def test_aapply_vector_index_ivf(self, vs): - index = 
IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) - await vs.aapply_vector_index(index, concurrently=True) - assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) - index = IVFFlatIndex( - name="secondindex", - distance_strategy=DistanceStrategy.INNER_PRODUCT, - ) - await vs.aapply_vector_index(index) - assert await vs.ais_valid_index("secondindex") - await vs.adrop_vector_index("secondindex") - await vs.adrop_vector_index() From de19120aa661432b8d4a7fde9ba6cc3b8efa23d3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:32:34 -0800 Subject: [PATCH 09/31] chore(deps): bump jinja2 from 3.1.4 to 3.1.5 in /.kokoro (#36) Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.4 to 3.1.5. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/jinja/compare/3.1.4...3.1.5) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Averi Kitsch --- .kokoro/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 88fb726..23e61f6 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -277,9 +277,9 @@ jeepney==0.8.0 \ # via # keyring # secretstorage -jinja2==3.1.4 \ - --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ - --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d +jinja2==3.1.5 \ + --hash=sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb \ + --hash=sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb # via gcp-releasetool keyring==24.3.1 \ --hash=sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db \ @@ -525,4 +525,4 @@ zipp==3.19.1 \ # WARNING: The following packages were not pinned, but pip requires them to be # pinned when the requirements file includes hashes and the requirement is not # satisfied by a package already installed. Consider using the --allow-unsafe flag. 
-# setuptools \ No newline at end of file +# setuptools From 6ce6ba103aa9ed51649814e6e7283755ec0bcab3 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 7 Jan 2025 23:36:25 +0530 Subject: [PATCH 10/31] ci: Add blunderbuss config (#41) --- .github/blunderbuss.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .github/blunderbuss.yml diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml new file mode 100644 index 0000000..90d8c91 --- /dev/null +++ b/.github/blunderbuss.yml @@ -0,0 +1,4 @@ +assign_issues: + - googleapis/llama-index-cloud-sql +assign_prs: + - googleapis/llama-index-cloud-sql From 0ef1fa5c945c9012354fc6cacb4fc50dd12c0c19 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 8 Jan 2025 00:01:37 +0000 Subject: [PATCH 11/31] feat: Add chat store init methods (#39) Co-authored-by: Averi Kitsch --- src/llama_index_cloud_sql_pg/engine.py | 85 ++++++++++++++++++++++++++ tests/test_engine.py | 36 +++++++++++ 2 files changed, 121 insertions(+) diff --git a/src/llama_index_cloud_sql_pg/engine.py b/src/llama_index_cloud_sql_pg/engine.py index cc067db..2faa943 100644 --- a/src/llama_index_cloud_sql_pg/engine.py +++ b/src/llama_index_cloud_sql_pg/engine.py @@ -756,6 +756,91 @@ def init_index_store_table( ) ) + async def _ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat history. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + Returns: + None + """ + if overwrite_existing: + async with self._pool.connect() as conn: + await conn.execute( + text(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + ) + await conn.commit() + + create_table_query = f"""CREATE TABLE "{schema_name}"."{table_name}"( + id SERIAL PRIMARY KEY, + key VARCHAR NOT NULL, + message JSON NOT NULL + );""" + create_index_query = f"""CREATE INDEX "{table_name}_idx_key" ON "{schema_name}"."{table_name}" (key);""" + async with self._pool.connect() as conn: + await conn.execute(text(create_table_query)) + await conn.execute(text(create_index_query)) + await conn.commit() + + async def ainit_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. + Returns: + None + """ + await self._run_as_async( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + + def init_chat_store_table( + self, + table_name: str, + schema_name: str = "public", + overwrite_existing: bool = False, + ) -> None: + """ + Create an table to save chat store. + Args: + table_name (str): The table name to store chat store. + schema_name (str): The schema name to store the chat store table. + Default: "public". + overwrite_existing (bool): Whether to drop existing table. + Default: False. 
+ Returns: + None + """ + self._run_as_sync( + self._ainit_chat_store_table( + table_name, + schema_name, + overwrite_existing, + ) + ) + async def _aload_table_schema( self, table_name: str, schema_name: str = "public" ) -> Table: diff --git a/tests/test_engine.py b/tests/test_engine.py index fe89197..46af5d0 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -34,6 +34,8 @@ DEFAULT_IS_TABLE_SYNC = "index_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE = "vector_store_" + str(uuid.uuid4()) DEFAULT_VS_TABLE_SYNC = "vector_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE = "chat_store_" + str(uuid.uuid4()) +DEFAULT_CS_TABLE_SYNC = "chat_store_" + str(uuid.uuid4()) VECTOR_SIZE = 768 @@ -113,6 +115,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') await engine.close() async def test_password( @@ -296,6 +299,22 @@ async def test_init_index_store(self, engine): for row in results: assert row in expected + async def test_init_chat_store(self, engine): + await engine.ainit_chat_store_table( + table_name=DEFAULT_CS_TABLE, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected + @pytest.mark.asyncio class TestEngineSync: @@ -343,6 +362,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() async def test_password( @@ -461,3 +481,19 @@ async def test_init_index_store(self, engine): ] for row in results: assert row in expected + + async def test_init_chat_store(self, engine): + engine.init_chat_store_table( + table_name=DEFAULT_CS_TABLE_SYNC, + schema_name="public", + overwrite_existing=True, + ) + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{DEFAULT_CS_TABLE_SYNC}';" + results = await afetch(engine, stmt) + expected = [ + {"column_name": "id", "data_type": "integer"}, + {"column_name": "key", "data_type": "character varying"}, + {"column_name": "message", "data_type": "json"}, + ] + for row in results: + assert row in expected From 2b14f5a946e595bce145bf1b526138cf393250ed Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 8 Jan 2025 09:07:03 +0000 Subject: [PATCH 12/31] feat: Add Async Chat Store (#38) * feat: Add Async Chat Store * fix tests --------- Co-authored-by: Averi Kitsch --- .../async_chat_store.py | 295 ++++++++++++++++++ tests/test_async_chat_store.py | 218 +++++++++++++ 2 files changed, 513 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/async_chat_store.py create mode 100644 tests/test_async_chat_store.py diff --git a/src/llama_index_cloud_sql_pg/async_chat_store.py b/src/llama_index_cloud_sql_pg/async_chat_store.py new file mode 100644 index 
0000000..8d80543 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_chat_store.py @@ -0,0 +1,295 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Optional + +from llama_index.core.llms import ChatMessage +from llama_index.core.storage.chat_store.base import BaseChatStore +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + + +class AsyncPostgresChatStore(BaseChatStore): + """Chat Store Table stored in an CloudSQL for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + ): + """AsyncPostgresChatStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (PostgresEngine): Database connection pool. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AsyncPostgresChatStore.__create_key: + raise Exception("Only create class through 'create' method!") + + # Delegate to Pydantic's __init__ + super().__init__() + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + + @classmethod + async def create( + cls, + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + ) -> AsyncPostgresChatStore: + """Create a new AsyncPostgresChatStore instance. + + Args: + engine (PostgresEngine): Postgres engine to use. + table_name (str): Table name that stores the chat store. + schema_name (str): The schema name where the table is located. + Defaults to "public" + + Raises: + ValueError: If the table provided does not contain required schema. + + Returns: + AsyncPostgresChatStore: A newly created instance of AsyncPostgresChatStore. 
+ """ + table_schema = await engine._aload_table_schema(table_name, schema_name) + column_names = table_schema.columns.keys() + + required_columns = ["id", "key", "message"] + + if not (all(x in column_names for x in required_columns)): + raise ValueError( + f"Table '{schema_name}'.'{table_name}' has an incorrect schema.\n" + f"Expected column names: {required_columns}\n" + f"Provided column names: {column_names}\n" + "Please create the table with the following schema:\n" + f"CREATE TABLE {schema_name}.{table_name} (\n" + " id SERIAL PRIMARY KEY,\n" + " key VARCHAR NOT NULL,\n" + " message JSON NOT NULL\n" + ");" + ) + + return cls(cls.__create_key, engine._pool, table_name, schema_name) + + async def __aexecute_query(self, query, params=None): + async with self._engine.connect() as conn: + await conn.execute(text(query), params) + await conn.commit() + + async def __afetch_query(self, query): + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + await conn.commit() + return results + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "AsyncPostgresChatStore" + + async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Asynchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}'; """ + await self.__aexecute_query(query) + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + + params = [ + { + "key": key, + "message": json.dumps(message.dict()), + } + for message in messages + ] + + await self.__aexecute_query(insert_query, params) + + async def aget_messages(self, key: str) -> List[ChatMessage]: + """Asynchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + query = f"""SELECT message from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;""" + results = await self.__afetch_query(query) + if results: + return [ + ChatMessage.model_validate(result.get("message")) for result in results + ] + return [] + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + insert_query = f""" + INSERT INTO "{self._schema_name}"."{self._table_name}" (key, message) + VALUES (:key, :message);""" + params = {"key": key, "message": json.dumps(message.dict())} + + await self.__aexecute_query(insert_query, params) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. 
+
+        Returns:
+            Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages
+                were associated with the key or could be deleted.
+        """
+        query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' RETURNING *; """
+        results = await self.__afetch_query(query)
+        if results:
+            return [
+                ChatMessage.model_validate(result.get("message")) for result in results
+            ]
+        return None
+
+    async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
+        """Asynchronously deletes a specific chat message by index from the messages associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose messages are to be deleted.
+            idx (int): The index of the `ChatMessage` to be deleted from the list of messages.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+        """
+        query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id;"""
+        results = await self.__afetch_query(query)
+        if results:
+            if idx >= len(results):
+                return None
+            id_to_be_deleted = results[idx].get("id")
+            delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;"""
+            result = await self.__afetch_query(delete_query)
+            result = result[0]
+            if result:
+                return ChatMessage.model_validate(result.get("message"))
+            return None
+        return None
+
+    async def adelete_last_message(self, key: str) -> Optional[ChatMessage]:
+        """Asynchronously deletes the last chat message associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose message is to be deleted.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+        """
+        query = f"""SELECT * from "{self._schema_name}"."{self._table_name}" WHERE key = '{key}' ORDER BY id DESC LIMIT 1;"""
+        results = await self.__afetch_query(query)
+        if results:
+            id_to_be_deleted = results[0].get("id")
+            delete_query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE id = '{id_to_be_deleted}' RETURNING *;"""
+            result = await self.__afetch_query(delete_query)
+            result = result[0]
+            if result:
+                return ChatMessage.model_validate(result.get("message"))
+            return None
+        return None
+
+    async def aget_keys(self) -> List[str]:
+        """Asynchronously retrieves a list of all keys.
+
+        Returns:
+            List[str]: A list of strings representing the keys. If no keys are found, an empty list is returned.
+        """
+        query = (
+            f"""SELECT distinct key from "{self._schema_name}"."{self._table_name}";"""
+        )
+        results = await self.__afetch_query(query)
+        keys = []
+        if results:
+            keys = [row.get("key") for row in results]
+        return keys
+
+    def set_messages(self, key: str, messages: List[ChatMessage]) -> None:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
+
+    def get_messages(self, key: str) -> List[ChatMessage]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
+
+    def add_message(self, key: str, message: ChatMessage) -> None:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
+
+    def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
+
+    def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
+
+    def delete_last_message(self, key: str) -> Optional[ChatMessage]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
+
+    def get_keys(self) -> List[str]:
+        raise NotImplementedError(
+            "Sync methods are not implemented for AsyncPostgresChatStore. Use PostgresChatStore interface instead."
+        )
diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py
new file mode 100644
index 0000000..dd22ac1
--- /dev/null
+++ b/tests/test_async_chat_store.py
@@ -0,0 +1,218 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import uuid
+from typing import Sequence
+
+import pytest
+import pytest_asyncio
+from llama_index.core.llms import ChatMessage
+from sqlalchemy import RowMapping, text
+
+from llama_index_cloud_sql_pg import PostgresEngine
+from llama_index_cloud_sql_pg.async_chat_store import AsyncPostgresChatStore
+
+default_table_name_async = "chat_store_" + str(uuid.uuid4())
+
+
+async def aexecute(engine: PostgresEngine, query: str) -> None:
+    async with engine._pool.connect() as conn:
+        await conn.execute(text(query))
+        await conn.commit()
+
+
+async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]:
+    async with engine._pool.connect() as conn:
+        result = await conn.execute(text(query))
+        result_map = result.mappings()
+        result_fetch = result_map.fetchall()
+    return result_fetch
+
+
+def get_env_var(key: str, desc: str) -> str:
+    v = os.environ.get(key)
+    if v is None:
+        raise ValueError(f"Must set env var {key} to: {desc}")
+    return v
+
+
+@pytest.mark.asyncio(loop_scope="class")
+class TestAsyncPostgresChatStores:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for Cloud SQL instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")
+
+    @pytest.fixture(scope="module")
+    def user(self) -> str:
+        return get_env_var("DB_USER", "database user for Cloud SQL")
+
+    @pytest.fixture(scope="module")
+    def password(self) -> str:
+        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")
+
+    @pytest_asyncio.fixture(scope="class")
+    async def async_engine(
+        self,
+        db_project,
+        db_region,
db_instance, + db_name, + ): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def chat_store(self, async_engine): + await async_engine._ainit_chat_store_table(table_name=default_table_name_async) + + chat_store = await AsyncPostgresChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AsyncPostgresChatStore( + engine=async_engine, table_name=default_table_name_async + ) + + async def test_async_add_message(self, async_engine, chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await chat_store.aset_messages(key, messages) + + results = await chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await chat_store.aset_messages(key, messages) + + await chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, chat_store): + message_1 = 
[ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await chat_store.aset_messages(key_1, message_1) + await chat_store.aset_messages(key_2, message_2) + + keys = await chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. + assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() From 7787d7d1161dd994c11ac8a75eb5890cf9309cee Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 8 Jan 2025 17:28:52 +0000 Subject: [PATCH 13/31] feat: Add Postgres Chat Store (#40) * feat: Add Postgres Chat Store * Linter fix --- src/llama_index_cloud_sql_pg/__init__.py | 2 + src/llama_index_cloud_sql_pg/chat_store.py | 289 ++++++++++++++++ tests/test_chat_store.py | 382 +++++++++++++++++++++ 3 files changed, 673 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/chat_store.py create mode 100644 tests/test_chat_store.py diff --git a/src/llama_index_cloud_sql_pg/__init__.py b/src/llama_index_cloud_sql_pg/__init__.py index 2916607..4a367b5 100644 --- a/src/llama_index_cloud_sql_pg/__init__.py +++ b/src/llama_index_cloud_sql_pg/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .chat_store import PostgresChatStore from .document_store import PostgresDocumentStore from .engine import Column, PostgresEngine from .index_store import PostgresIndexStore @@ -20,6 +21,7 @@ _all = [ "Column", + "PostgresChatStore", "PostgresEngine", "PostgresDocumentStore", "PostgresIndexStore", diff --git a/src/llama_index_cloud_sql_pg/chat_store.py b/src/llama_index_cloud_sql_pg/chat_store.py new file mode 100644 index 0000000..db277e9 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/chat_store.py @@ -0,0 +1,289 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from typing import List, Optional
+
+from llama_index.core.llms import ChatMessage
+from llama_index.core.storage.chat_store.base import BaseChatStore
+
+from .async_chat_store import AsyncPostgresChatStore
+from .engine import PostgresEngine
+
+
+class PostgresChatStore(BaseChatStore):
+    """Chat Store Table stored in a Cloud SQL for PostgreSQL database."""
+
+    __create_key = object()
+
+    def __init__(
+        self, key: object, engine: PostgresEngine, chat_store: AsyncPostgresChatStore
+    ):
+        """PostgresChatStore constructor.
+
+        Args:
+            key (object): Key to prevent direct constructor usage.
+            engine (PostgresEngine): Database connection pool.
+            chat_store (AsyncPostgresChatStore): The async-only ChatStore implementation.
+
+        Raises:
+            Exception: If constructor is directly called by the user.
+        """
+        if key != PostgresChatStore.__create_key:
+            raise Exception(
+                "Only create class through 'create' or 'create_sync' methods!"
+            )
+
+        # Delegate to Pydantic's __init__
+        super().__init__()
+        self._engine = engine
+        self.__chat_store = chat_store
+
+    @classmethod
+    async def create(
+        cls,
+        engine: PostgresEngine,
+        table_name: str,
+        schema_name: str = "public",
+    ) -> PostgresChatStore:
+        """Create a new PostgresChatStore instance.
+
+        Args:
+            engine (PostgresEngine): Postgres engine to use.
+            table_name (str): Table name that stores the chat store.
+            schema_name (str): The schema name where the table is located. Defaults to "public".
+
+        Raises:
+            ValueError: If the table provided does not contain required schema.
+
+        Returns:
+            PostgresChatStore: A newly created instance of PostgresChatStore.
+        """
+        coro = AsyncPostgresChatStore.create(engine, table_name, schema_name)
+        chat_store = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, chat_store)
+
+    @classmethod
+    def create_sync(
+        cls,
+        engine: PostgresEngine,
+        table_name: str,
+        schema_name: str = "public",
+    ) -> PostgresChatStore:
+        """Create a new PostgresChatStore sync instance.
+
+        Args:
+            engine (PostgresEngine): Postgres engine to use.
+            table_name (str): Table name that stores the chat store.
+            schema_name (str): The schema name where the table is located. Defaults to "public".
+
+        Raises:
+            ValueError: If the table provided does not contain required schema.
+
+        Returns:
+            PostgresChatStore: A newly created instance of PostgresChatStore.
+        """
+        coro = AsyncPostgresChatStore.create(engine, table_name, schema_name)
+        chat_store = engine._run_as_sync(coro)
+        return cls(cls.__create_key, engine, chat_store)
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Get class name."""
+        return "PostgresChatStore"
+
+    async def aset_messages(self, key: str, messages: List[ChatMessage]) -> None:
+        """Asynchronously sets the chat messages for a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat.
+            messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert.
+
+        Returns:
+            None
+        """
+        return await self._engine._run_as_async(
+            self.__chat_store.aset_messages(key=key, messages=messages)
+        )
+
+    async def aget_messages(self, key: str) -> List[ChatMessage]:
+        """Asynchronously retrieves the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for which the messages are to be retrieved.
+
+        Returns:
+            List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key.
+                If no messages are found, an empty list is returned.
+ """ + return await self._engine._run_as_async( + self.__chat_store.aget_messages(key=key) + ) + + async def async_add_message(self, key: str, message: ChatMessage) -> None: + """Asynchronously adds a new chat message to the specified key. + + Args: + key (str): A unique identifierfor the chat to which the message is added. + message (ChatMessage): The `ChatMessage` object that is to be added. + + Returns: + None + """ + return await self._engine._run_as_async( + self.__chat_store.async_add_message(key=key, message=message) + ) + + async def adelete_messages(self, key: str) -> Optional[List[ChatMessage]]: + """Asynchronously deletes the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + + Returns: + Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages + were associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_messages(key=key) + ) + + async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: + """Asynchronously deletes a specific chat message by index from the messages associated with a given key. + + Args: + key (str): A unique identifier for the chat whose messages are to be deleted. + idx (int): The index of the `ChatMessage` to be deleted from the list of messages. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_message(key=key, idx=idx) + ) + + async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: + """Asynchronously deletes the last chat message associated with a given key. + + Args: + key (str): A unique identifier for the chat whose message is to be deleted. + + Returns: + Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message + was associated with the key or could be deleted. + """ + return await self._engine._run_as_async( + self.__chat_store.adelete_last_message(key=key) + ) + + async def aget_keys(self) -> List[str]: + """Asynchronously retrieves a list of all keys. + + Returns: + Optional[str]: A list of strings representing the keys. If no keys are found, an empty list is returned. + """ + return await self._engine._run_as_async(self.__chat_store.aget_keys()) + + def set_messages(self, key: str, messages: List[ChatMessage]) -> None: + """Synchronously sets the chat messages for a specific key. + + Args: + key (str): A unique identifier for the chat. + messages (List[ChatMessage]): A list of `ChatMessage` objects to upsert. + + Returns: + None + + """ + return self._engine._run_as_sync( + self.__chat_store.aset_messages(key=key, messages=messages) + ) + + def get_messages(self, key: str) -> List[ChatMessage]: + """Synchronously retrieves the chat messages associated with a specific key. + + Args: + key (str): A unique identifier for which the messages are to be retrieved. + + Returns: + List[ChatMessage]: A list of `ChatMessage` objects associated with the provided key. + If no messages are found, an empty list is returned. + """ + return self._engine._run_as_sync(self.__chat_store.aget_messages(key=key)) + + def add_message(self, key: str, message: ChatMessage) -> None: + """Synchronously adds a new chat message to the specified key. 
+
+        Args:
+            key (str): A unique identifier for the chat to which the message is added.
+            message (ChatMessage): The `ChatMessage` object that is to be added.
+
+        Returns:
+            None
+        """
+        return self._engine._run_as_sync(
+            self.__chat_store.async_add_message(key=key, message=message)
+        )
+
+    def delete_messages(self, key: str) -> Optional[List[ChatMessage]]:
+        """Synchronously deletes the chat messages associated with a specific key.
+
+        Args:
+            key (str): A unique identifier for the chat whose messages are to be deleted.
+
+        Returns:
+            Optional[List[ChatMessage]]: A list of `ChatMessage` objects that were deleted, or `None` if no messages
+                were associated with the key or could be deleted.
+        """
+        return self._engine._run_as_sync(self.__chat_store.adelete_messages(key=key))
+
+    def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]:
+        """Synchronously deletes a specific chat message by index from the messages associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose messages are to be deleted.
+            idx (int): The index of the `ChatMessage` to be deleted from the list of messages.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+        """
+        return self._engine._run_as_sync(
+            self.__chat_store.adelete_message(key=key, idx=idx)
+        )
+
+    def delete_last_message(self, key: str) -> Optional[ChatMessage]:
+        """Synchronously deletes the last chat message associated with a given key.
+
+        Args:
+            key (str): A unique identifier for the chat whose message is to be deleted.
+
+        Returns:
+            Optional[ChatMessage]: The `ChatMessage` object that was deleted, or `None` if no message
+                was associated with the key or could be deleted.
+        """
+        return self._engine._run_as_sync(
+            self.__chat_store.adelete_last_message(key=key)
+        )
+
+    def get_keys(self) -> List[str]:
+        """Synchronously retrieves a list of all keys.
+
+        Returns:
+            List[str]: A list of strings representing the keys. If no keys are found, an empty list is returned.
+        """
+        return self._engine._run_as_sync(self.__chat_store.aget_keys())
diff --git a/tests/test_chat_store.py b/tests/test_chat_store.py
new file mode 100644
index 0000000..305948c
--- /dev/null
+++ b/tests/test_chat_store.py
@@ -0,0 +1,382 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
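The synchronous wrapper above delegates every call to the async implementation through the engine's background event loop (_run_as_sync / _run_as_async). A minimal sketch of the sync path, using illustrative connection values and the blocking counterparts of the async initializers shown elsewhere in this series:

    from llama_index.core.llms import ChatMessage
    from llama_index_cloud_sql_pg import PostgresChatStore, PostgresEngine

    # Blocking engine construction and table creation.
    engine = PostgresEngine.from_instance(
        project_id="my-project",
        region="my-region",
        instance="my-instance",
        database="my-database",
    )
    engine.init_chat_store_table(table_name="chat_store")

    # create_sync drives the async factory to completion on the engine's loop.
    store = PostgresChatStore.create_sync(engine=engine, table_name="chat_store")
    store.set_messages("user-1", [ChatMessage(content="hello", role="user")])
    print(store.get_keys())  # expected to include "user-1"
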
+ +import os +import uuid +import warnings +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.llms import ChatMessage +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresChatStore, PostgresEngine + +default_table_name_async = "chat_store_" + str(uuid.uuid4()) +default_table_name_sync = "chat_store_" + str(uuid.uuid4()) + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresChatStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def async_chat_store(self, async_engine): + await async_engine.ainit_chat_store_table(table_name=default_table_name_async) + + async_chat_store = await PostgresChatStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield async_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + PostgresChatStore(engine=async_engine, table_name=default_table_name_async) + + async def test_async_add_message(self, async_engine, async_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + await async_chat_store.async_add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_aset_and_aget_messages(self, async_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = 
ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + await async_chat_store.aset_messages(key, messages) + + results = await async_chat_store.aget_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_adelete_messages(self, async_engine, async_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_messages(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 0 + + async def test_adelete_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_message(key, 1) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_adelete_last_message(self, async_engine, async_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + await async_chat_store.aset_messages(key, messages) + + await async_chat_store.adelete_last_message(key) + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}' ORDER BY id;""" + results = await afetch(async_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_aget_keys(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + await async_chat_store.aset_messages(key_1, message_1) + await async_chat_store.aset_messages(key_2, message_2) + + keys = await async_chat_store.aget_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, async_engine, async_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + await async_chat_store.aset_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + await async_chat_store.aset_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_async}" where key = '{key}';""" + results = await afetch(async_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. 
+ assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresChatStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine(self, db_project, db_region, db_instance, db_name): + sync_engine = PostgresEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await sync_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def sync_chat_store(self, sync_engine): + sync_engine.init_chat_store_table(table_name=default_table_name_sync) + + sync_chat_store = PostgresChatStore.create_sync( + engine=sync_engine, table_name=default_table_name_sync + ) + + yield sync_chat_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_sync}"' + await aexecute(sync_engine, query) + + async def test_init_with_constructor(self, sync_engine): + with pytest.raises(Exception): + PostgresChatStore(engine=sync_engine, table_name=default_table_name_sync) + + async def test_add_message(self, sync_engine, sync_chat_store): + key = "test_add_key" + + message = ChatMessage(content="add_message_test", role="user") + sync_chat_store.add_message(key, message=message) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + result = results[0] + assert result["message"] == message.dict() + + async def test_set_and_get_messages(self, sync_chat_store): + message_1 = ChatMessage(content="First message", role="user") + message_2 = ChatMessage(content="Second message", role="user") + messages = [message_1, message_2] + key = "test_set_and_get_key" + sync_chat_store.set_messages(key, messages) + + results = sync_chat_store.get_messages(key) + + assert len(results) == 2 + assert results[0].content == message_1.content + assert results[1].content == message_2.content + + async def test_delete_messages(self, sync_engine, sync_chat_store): + messages = [ChatMessage(content="Message to delete", role="user")] + key = "test_delete_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_messages(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 0 + + async def test_delete_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Keep me", role="user") + message_2 = ChatMessage(content="Delete me", role="user") + messages = [message_1, message_2] + key = "test_delete_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_message(key, 
1) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + assert results[0]["message"] == message_1.dict() + + async def test_delete_last_message(self, sync_engine, sync_chat_store): + message_1 = ChatMessage(content="Message 1", role="user") + message_2 = ChatMessage(content="Message 2", role="user") + message_3 = ChatMessage(content="Message 3", role="user") + messages = [message_1, message_2, message_3] + key = "test_delete_last_message_key" + sync_chat_store.set_messages(key, messages) + + sync_chat_store.delete_last_message(key) + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}' ORDER BY id;""" + results = await afetch(sync_engine, query) + + assert len(results) == 2 + assert results[0]["message"] == message_1.dict() + assert results[1]["message"] == message_2.dict() + + async def test_get_keys(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + message_2 = [ChatMessage(content="Second message", role="user")] + key_1 = "key1" + key_2 = "key2" + sync_chat_store.set_messages(key_1, message_1) + sync_chat_store.set_messages(key_2, message_2) + + keys = sync_chat_store.get_keys() + + assert key_1 in keys + assert key_2 in keys + + async def test_set_exisiting_key(self, sync_engine, sync_chat_store): + message_1 = [ChatMessage(content="First message", role="user")] + key = "test_set_exisiting_key" + sync_chat_store.set_messages(key, message_1) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + assert len(results) == 1 + result = results[0] + assert result["message"] == message_1[0].dict() + + message_2 = ChatMessage(content="Second message", role="user") + message_3 = ChatMessage(content="Third message", role="user") + messages = [message_2, message_3] + + sync_chat_store.set_messages(key, messages) + + query = f"""select * from "public"."{default_table_name_sync}" where key = '{key}';""" + results = await afetch(sync_engine, query) + + # Assert the previous messages are deleted and only the newest ones exist. 
+ assert len(results) == 2 + + assert results[0]["message"] == message_2.dict() + assert results[1]["message"] == message_3.dict() From c65f3913709aada1300b364716872f421c68d5f3 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Wed, 8 Jan 2025 18:55:10 +0100 Subject: [PATCH 14/31] chore(deps): update python-nonmajor (#35) Co-authored-by: Averi Kitsch --- pyproject.toml | 4 ++-- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 58adf9e..f14161c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/ test = [ "black[jupyter]==24.10.0", "isort==5.13.2", - "mypy==1.13.0", - "pytest-asyncio==0.25.0", + "mypy==1.14.1", + "pytest-asyncio==0.25.2", "pytest==8.3.4", "pytest-cov==6.0.0" ] diff --git a/requirements.txt b/requirements.txt index 8fe0878..b1d16e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.15.0 -llama-index-core==0.12.6 +llama-index-core==0.12.10.post1 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.36 From f5e54dd6c2b8083c19402db6066d9ba5bcad4d01 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 03:07:20 +0530 Subject: [PATCH 15/31] chore: add code coverage (#32) * chore: add code coverage * fix: add tests to boost code coverage * fix: incorrect exception messages * fix: incorrect connection settings --------- Co-authored-by: Averi Kitsch --- .coveragerc | 8 +++ integration.cloudbuild.yaml | 2 +- tests/test_async_document_store.py | 82 ++++++++++++++++++++++++++++-- tests/test_async_index_store.py | 27 +++++++++- tests/test_async_vector_store.py | 72 ++++++++++++++++++++++++-- tests/test_document_store.py | 3 +- tests/test_engine.py | 73 +++++++++++++++++++++++--- tests/test_index_store.py | 10 +++- tests/test_vector_store.py | 10 ++-- 9 files changed, 264 insertions(+), 23 deletions(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b21412b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,8 @@ +[run] +branch = true +omit = + */__init__.py + +[report] +show_missing = true +fail_under = 90 diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index ffa5a05..769a42f 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -46,7 +46,7 @@ steps: - "-c" - | /workspace/cloud-sql-proxy ${_INSTANCE_CONNECTION_NAME} --port $_DATABASE_PORT & sleep 2; - python -m pytest tests/ + python -m pytest --cov=llama_index_cloud_sql_pg --cov-config=.coveragerc tests/ availableSecrets: secretManager: diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index d04db53..4c0dacb 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -28,6 +28,7 @@ default_table_name_async = "document_store_" + str(uuid.uuid4()) custom_table_name_async = "document_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresDocumentStore. Use PostgresDocumentStore interface instead." 
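An aside on the test_init_with_constructor updates in this patch: they all assert the same construction contract, namely that the stores and the engine can only be built through their create factories. The guard is the private sentinel-key pattern already used by AsyncPostgresChatStore above; a stripped-down sketch with illustrative names:

    class Guarded:
        # Module-private sentinel; name mangling keeps it out of callers' reach.
        __create_key = object()

        def __init__(self, key: object):
            # Reject any caller that did not come through the factory.
            if key != Guarded.__create_key:
                raise Exception("Only create class through 'create' method!")

        @classmethod
        def create(cls) -> "Guarded":
            return cls(cls.__create_key)
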
async def aexecute(engine: PostgresEngine, query: str) -> None: @@ -116,9 +117,16 @@ async def custom_doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncPostgresDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresDocumentStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_warning(self, custom_doc_store): @@ -178,7 +186,7 @@ async def test_add_hash_before_data(self, async_engine, doc_store): result = results[0] assert result["node_data"][DATA_KEY]["text_resource"]["text"] == document_text - async def test_ref_doc_exists(self, doc_store): + async def test_aref_doc_exists(self, doc_store): # Create a ref_doc & a doc and add them to the store. ref_doc = Document( text="first doc", id_="doc_exists_doc_1", metadata={"doc": "info"} @@ -235,6 +243,8 @@ async def test_adelete_ref_doc(self, doc_store): assert ( await doc_store.aget_document(doc_id=doc.doc_id, raise_error=False) is None ) + # Confirm deleting an non-existent reference doc returns None. + assert await doc_store.adelete_ref_doc(ref_doc_id=ref_doc.doc_id) is None async def test_set_and_get_document_hash(self, doc_store): # Set a doc hash for a document @@ -245,6 +255,9 @@ async def test_set_and_get_document_hash(self, doc_store): # Assert with get that the hash is same as the one set. assert await doc_store.aget_document_hash(doc_id=doc_id) == doc_hash + async def test_aget_document_hash(self, doc_store): + assert await doc_store.aget_document_hash(doc_id="non-existent-doc") is None + async def test_set_and_get_document_hashes(self, doc_store): # Create a dictionary of doc_id -> doc_hash mappings and add it to the table. document_dict = { @@ -279,7 +292,7 @@ async def test_doc_store_basic(self, doc_store): retrieved_node = await doc_store.aget_document(doc_id=node.node_id) assert retrieved_node == node - async def test_delete_document(self, async_engine, doc_store): + async def test_adelete_document(self, async_engine, doc_store): # Create a doc and add it to the store. doc = Document(text="document_2", id_="doc_id_2", metadata={"doc": "info"}) await doc_store.async_add_documents([doc]) @@ -292,6 +305,11 @@ async def test_delete_document(self, async_engine, doc_store): result = await afetch(async_engine, query) assert len(result) == 0 + async def test_delete_non_existent_document(self, doc_store): + await doc_store.adelete_document(doc_id="non-existent-doc", raise_error=False) + with pytest.raises(ValueError): + await doc_store.adelete_document(doc_id="non-existent-doc") + async def test_doc_store_ref_doc_not_added(self, async_engine, doc_store): # Create a ref_doc & doc. 
ref_doc = Document( @@ -367,3 +385,61 @@ async def test_doc_store_delete_all_ref_doc_nodes(self, async_engine, doc_store) query = f"""select * from "public"."{default_table_name_async}" where id = '{ref_doc.doc_id}';""" result = await afetch(async_engine, query) assert len(result) == 0 + + async def test_docs(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.docs() + + async def test_add_documents(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.add_documents([]) + + async def test_get_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document("test_doc_id", raise_error=False) + + async def test_delete_document(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=True) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_document("test_doc_id", raise_error=False) + + async def test_document_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.document_exists("test_doc_id") + + async def test_ref_doc_exists(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.ref_doc_exists(ref_doc_id="test_ref_doc_id") + + async def test_set_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hash("test_doc_id", "test_doc_hash") + + async def test_set_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.set_document_hashes({"test_doc_id": "test_doc_hash"}) + + async def test_get_document_hash(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_document_hash(doc_id="test_doc_id") + + async def test_get_all_document_hashes(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_document_hashes() + + async def test_get_all_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_all_ref_doc_info() + + async def test_get_ref_doc_info(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.get_ref_doc_info(ref_doc_id="test_doc_id") + + async def test_delete_ref_doc(self, doc_store): + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=False) + with pytest.raises(Exception, match=sync_method_exception_str): + doc_store.delete_ref_doc(ref_doc_id="test_doc_id", raise_error=True) diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index 3736fda..d0d6f6c 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -26,6 +26,7 @@ from llama_index_cloud_sql_pg.async_index_store import AsyncPostgresIndexStore default_table_name_async = "index_store_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresIndexStore . Use PostgresIndexStore interface instead." 
async def aexecute(engine: PostgresEngine, query: str) -> None: @@ -102,9 +103,16 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): AsyncPostgresIndexStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async + ) + + async def test_create_without_table(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresIndexStore.create( + engine=async_engine, table_name="non-existent-table" ) async def test_add_and_delete_index(self, index_store, async_engine): @@ -162,3 +170,20 @@ async def test_warning(self, index_store): assert "No struct_id specified and more than one struct exists." in str( w[-1].message ) + + async def test_index_structs(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.index_structs() + + async def test_add_index_struct(self, index_store): + index_struct = IndexGraph() + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.add_index_struct(index_struct) + + async def test_delete_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.delete_index_struct("non_existent_key") + + async def test_get_index_struct(self, index_store): + with pytest.raises(Exception, match=sync_method_exception_str): + index_store.get_index_struct(struct_id="non_existent_id") diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index c38ceee..785a5b0 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -14,6 +14,7 @@ import os import uuid +import warnings from typing import Sequence import pytest @@ -109,8 +110,8 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() @pytest_asyncio.fixture(scope="class") @@ -153,8 +154,9 @@ async def custom_vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - AsyncPostgresVectorStore(engine, table_name=DEFAULT_TABLE) + AsyncPostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -313,6 +315,70 @@ async def test_aquery(self, engine, vs): assert len(results.nodes) == 3 assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + async def test_aquery_filters(self, engine, custom_vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs.async_add(nodes) + + filters = MetadataFilters( + filters=[ + MetadataFilter( + key="some_test_column", + value=["value_should_be_ignored"], + operator=FilterOperator.CONTAINS, + ), + MetadataFilter( + key="len", + value=3, + operator=FilterOperator.LTE, + ), + MetadataFilter( + key="len", + value=3, + 
operator=FilterOperator.GTE, + ), + MetadataFilter( + key="len", + value=2, + operator=FilterOperator.GT, + ), + MetadataFilter( + key="len", + value=4, + operator=FilterOperator.LT, + ), + MetadataFilters( + filters=[ + MetadataFilter( + key="len", + value=6.0, + operator=FilterOperator.NE, + ), + ], + condition=FilterCondition.OR, + ), + ], + condition=FilterCondition.AND, + ) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, filters=filters, similarity_top_k=-1 + ) + with warnings.catch_warnings(record=True) as w: + results = await custom_vs.aquery(query) + + assert len(w) == 1 + assert "Expecting a scalar in the filter value" in str(w[-1].message) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + async def test_aclear(self, engine, vs): # Note: To be migrated to a pytest dependency on test_adelete # Blocked due to unexpected fixtures reloads while running integration test suite diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 61d6786..b011dd5 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -117,9 +117,10 @@ async def doc_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): PostgresDocumentStore( - engine=async_engine, table_name=default_table_name_async + key, engine=async_engine, table_name=default_table_name_async ) async def test_async_add_document(self, async_engine, doc_store): diff --git a/tests/test_engine.py b/tests/test_engine.py index 46af5d0..f6df414 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -112,10 +112,11 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE}"') + + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_DS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE}"') await engine.close() async def test_password( @@ -359,12 +360,68 @@ async def engine(self, db_project, db_region, db_instance, db_name): database=db_name, ) yield engine - await aexecute(engine, f'DROP TABLE "{DEFAULT_DS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_IS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_VS_TABLE_SYNC}"') - await aexecute(engine, f'DROP TABLE "{DEFAULT_CS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_DS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE_SYNC}"') + await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + async def test_init_with_constructor( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + async def getconn() -> asyncpg.Connection: + conn = await connector.connect_async( # type: ignore + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=user, + password=password, + db=db_name, + enable_iam_auth=False, + ip_type=IPTypes.PUBLIC, + ) + return conn 
+ + engine = create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + ) + + key = object() + with pytest.raises(Exception): + PostgresEngine(key, engine) + + async def test_missing_user_or_password( + self, + db_project, + db_region, + db_instance, + db_name, + user, + password, + ): + with pytest.raises(ValueError): + await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + user=user, + ) + with pytest.raises(ValueError): + await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + password=password, + ) + async def test_password( self, db_project, diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 6df2017..5f840e7 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -111,8 +111,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - PostgresIndexStore(engine=async_engine, table_name=default_table_name_async) + PostgresIndexStore( + key, engine=async_engine, table_name=default_table_name_async + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() @@ -224,8 +227,11 @@ async def index_store(self, async_engine): await aexecute(async_engine, query) async def test_init_with_constructor(self, async_engine): + key = object() with pytest.raises(Exception): - PostgresIndexStore(engine=async_engine, table_name=default_table_name_sync) + PostgresIndexStore( + key, engine=async_engine, table_name=default_table_name_sync + ) async def test_add_and_delete_index(self, index_store, async_engine): index_struct = IndexGraph() diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index c365730..64fa303 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -117,7 +117,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield sync_engine - await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() @pytest_asyncio.fixture(scope="class") @@ -129,8 +129,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + PostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" @@ -491,7 +492,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield sync_engine - await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() @pytest_asyncio.fixture(scope="class") @@ -503,8 +504,9 @@ async def vs(self, engine): yield vs async def test_init_with_constructor(self, engine): + key = object() with pytest.raises(Exception): - PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + PostgresVectorStore(key, engine, table_name=DEFAULT_TABLE) async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" From bcd4100c452422534b3a7c77118e629c9db18f88 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 04:56:07 +0530 Subject: [PATCH 16/31] chore: test cleanup (#43) * chore: remove pytest warning * 
chore: better engine cleanup * fix: remove params added to engine init * chore: remove engine connector.close when created from engine args * chore: remove engine connector.close from component classes --------- Co-authored-by: Averi Kitsch --- pyproject.toml | 5 ++++- tests/test_async_chat_store.py | 1 + tests/test_async_document_store.py | 1 + tests/test_async_index_store.py | 1 + tests/test_async_vector_store.py | 1 + tests/test_async_vector_store_index.py | 1 + tests/test_chat_store.py | 2 ++ tests/test_document_store.py | 2 ++ tests/test_engine.py | 4 ++++ tests/test_index_store.py | 2 ++ tests/test_vector_store.py | 2 ++ tests/test_vector_store_index.py | 2 ++ 12 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f14161c..b816c10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ test = [ requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "class" + [tool.black] target-version = ['py39'] @@ -64,4 +67,4 @@ disallow_incomplete_defs = true exclude = [ 'docs/*', 'noxfile.py' -] \ No newline at end of file +] diff --git a/tests/test_async_chat_store.py b/tests/test_async_chat_store.py index dd22ac1..bdaf13f 100644 --- a/tests/test_async_chat_store.py +++ b/tests/test_async_chat_store.py @@ -92,6 +92,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def chat_store(self, async_engine): diff --git a/tests/test_async_document_store.py b/tests/test_async_document_store.py index 4c0dacb..d582ef4 100644 --- a/tests/test_async_document_store.py +++ b/tests/test_async_document_store.py @@ -90,6 +90,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): diff --git a/tests/test_async_index_store.py b/tests/test_async_index_store.py index d0d6f6c..0b7bbe8 100644 --- a/tests/test_async_index_store.py +++ b/tests/test_async_index_store.py @@ -88,6 +88,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py index 785a5b0..53752f0 100644 --- a/tests/test_async_vector_store.py +++ b/tests/test_async_vector_store.py @@ -113,6 +113,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_CUSTOM_VS}"') await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index 1f392fc..9aaf7ed 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -99,6 +99,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git 
a/tests/test_chat_store.py b/tests/test_chat_store.py index 305948c..694119b 100644 --- a/tests/test_chat_store.py +++ b/tests/test_chat_store.py @@ -96,6 +96,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def async_chat_store(self, async_engine): @@ -258,6 +259,7 @@ async def sync_engine(self, db_project, db_region, db_instance, db_name): yield sync_engine await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def sync_chat_store(self, sync_engine): diff --git a/tests/test_document_store.py b/tests/test_document_store.py index b011dd5..6432ecb 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -102,6 +102,7 @@ async def async_engine( yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def doc_store(self, async_engine): @@ -388,6 +389,7 @@ async def sync_engine( yield sync_engine await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def sync_doc_store(self, sync_engine): diff --git a/tests/test_engine.py b/tests/test_engine.py index f6df414..9c2b31d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -118,6 +118,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_IS_TABLE}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE}"') await engine.close() + await engine._connector.close_async() async def test_password( self, @@ -234,6 +235,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close_async() async def test_init_document_store(self, engine): await engine.ainit_doc_store_table( @@ -365,6 +367,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_VS_TABLE_SYNC}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_CS_TABLE_SYNC}"') await engine.close() + await engine._connector.close_async() async def test_init_with_constructor( self, @@ -471,6 +474,7 @@ async def test_iam_account_override( assert engine await aexecute(engine, "SELECT 1") await engine.close() + await engine._connector.close_async() async def test_init_document_store(self, engine): engine.init_doc_store_table( diff --git a/tests/test_index_store.py b/tests/test_index_store.py index 5f840e7..58bf057 100644 --- a/tests/test_index_store.py +++ b/tests/test_index_store.py @@ -96,6 +96,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): @@ -212,6 +213,7 @@ async def async_engine(self, db_project, db_region, db_instance, db_name): yield async_engine await async_engine.close() + await async_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def index_store(self, async_engine): diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py index 64fa303..a31bd2e 100644 --- a/tests/test_vector_store.py +++ b/tests/test_vector_store.py @@ -119,6 +119,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await 
aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -494,6 +495,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') await sync_engine.close() + await sync_engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py index a0ac37c..e316f40 100644 --- a/tests/test_vector_store_index.py +++ b/tests/test_vector_store_index.py @@ -108,6 +108,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): @@ -198,6 +199,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") await engine.close() + await engine._connector.close_async() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): From d4054e1c39ac4a45e19eb1a061b5019a4594ec6f Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 14 Jan 2025 05:19:03 +0530 Subject: [PATCH 17/31] chore: add drop index statements to avoid conflict (#44) Co-authored-by: Averi Kitsch --- tests/test_async_vector_store_index.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py index 9aaf7ed..26dc4af 100644 --- a/tests/test_async_vector_store_index.py +++ b/tests/test_async_vector_store_index.py @@ -119,6 +119,7 @@ async def test_aapply_vector_index(self, vs): index = HNSWIndex() await vs.aapply_vector_index(index) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index(DEFAULT_INDEX_NAME) async def test_areindex(self, vs): if not await vs.is_valid_index(DEFAULT_INDEX_NAME): @@ -127,6 +128,7 @@ async def test_areindex(self, vs): await vs.areindex() await vs.areindex(DEFAULT_INDEX_NAME) assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + await vs.adrop_vector_index() async def test_dropindex(self, vs): await vs.adrop_vector_index() @@ -143,6 +145,7 @@ async def test_aapply_vector_index_ivfflat(self, vs): ) await vs.aapply_vector_index(index) assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index() await vs.adrop_vector_index("secondindex") async def test_is_valid_index(self, vs): From 7ceae89174ccb698f93457ef7fb6a73b4c1e683b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:50:13 -0800 Subject: [PATCH 18/31] chore(deps): bump virtualenv from 20.25.1 to 20.26.6 in /.kokoro (#45) Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.25.1 to 20.26.6. - [Release notes](https://github.com/pypa/virtualenv/releases) - [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/virtualenv/compare/20.25.1...20.26.6) --- updated-dependencies: - dependency-name: virtualenv dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Averi Kitsch --- .kokoro/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.kokoro/requirements.txt b/.kokoro/requirements.txt index 23e61f6..b5a1d9a 100644 --- a/.kokoro/requirements.txt +++ b/.kokoro/requirements.txt @@ -509,9 +509,9 @@ urllib3==2.2.2 \ # via # requests # twine -virtualenv==20.25.1 \ - --hash=sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a \ - --hash=sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197 +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 # via nox wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ From fadcd3f604892aa414b13ba3bb0de13524b704b7 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 14 Jan 2025 19:02:28 +0100 Subject: [PATCH 19/31] chore(deps): update dependency sqlalchemy to v2.0.37 (#42) Co-authored-by: Averi Kitsch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b1d16e3..632ee76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.15.0 llama-index-core==0.12.10.post1 pgvector==0.3.6 -SQLAlchemy[asyncio]==2.0.36 +SQLAlchemy[asyncio]==2.0.37 From 5173e11831387909a12841bb232f8e39c113bd60 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:15:03 +0000 Subject: [PATCH 20/31] fix: query and return only selected metadata columns (#48) * fix: query and return only selected metadata columns * Review changes * Linter fix --- src/llama_index_cloud_sql_pg/async_vector_store.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama_index_cloud_sql_pg/async_vector_store.py b/src/llama_index_cloud_sql_pg/async_vector_store.py index 20baead..e864448 100644 --- a/src/llama_index_cloud_sql_pg/async_vector_store.py +++ b/src/llama_index_cloud_sql_pg/async_vector_store.py @@ -531,7 +531,19 @@ async def __query_columns( f" LIMIT {query.similarity_top_k} " if query.similarity_top_k >= 1 else "" ) - query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' + columns = self._metadata_columns + [ + self._id_column, + self._text_column, + self._embedding_column, + self._ref_doc_id_column, + self._node_column, + ] + if self._metadata_json_column: + columns.append(self._metadata_json_column) + + column_names = ", ".join(f'"{col}"' for col in columns) + + query_stmt = f'SELECT {column_names} {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' async with self._engine.connect() as conn: if self._index_query_options: query_options_stmt = ( From daa770f9a9d6f824a266ff5e98da6f89cf0d1713 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Fri, 17 Jan 2025 18:39:12 +0100 Subject: [PATCH 21/31] chore(deps): update python-nonmajor (#47) Co-authored-by: Averi Kitsch --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 632ee76..5e7f50c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -cloud-sql-python-connector[asyncpg]==1.15.0 
-llama-index-core==0.12.10.post1 +cloud-sql-python-connector[asyncpg]==1.16.0 +llama-index-core==0.12.11 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From 22ec16b669e2838b58dc5d969fa922660a513cca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 09:59:20 -0800 Subject: [PATCH 22/31] chore(deps): bump virtualenv in /.kokoro/docker/docs (#49) Bumps [virtualenv](https://github.com/pypa/virtualenv) from 20.26.0 to 20.26.6. - [Release notes](https://github.com/pypa/virtualenv/releases) - [Changelog](https://github.com/pypa/virtualenv/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/virtualenv/compare/20.26.0...20.26.6) --- updated-dependencies: - dependency-name: virtualenv dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .kokoro/docker/docs/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.kokoro/docker/docs/requirements.txt b/.kokoro/docker/docs/requirements.txt index 56381b8..43b2594 100644 --- a/.kokoro/docker/docs/requirements.txt +++ b/.kokoro/docker/docs/requirements.txt @@ -32,7 +32,7 @@ platformdirs==4.2.1 \ --hash=sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf \ --hash=sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1 # via virtualenv -virtualenv==20.26.0 \ - --hash=sha256:0846377ea76e818daaa3e00a4365c018bc3ac9760cbb3544de542885aad61fb3 \ - --hash=sha256:ec25a9671a5102c8d2657f62792a27b48f016664c6873f6beed3800008577210 - # via nox \ No newline at end of file +virtualenv==20.26.6 \ + --hash=sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48 \ + --hash=sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2 + # via nox From 6765391d806a12a09ad83ed155f04f82aaccccf6 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 21 Jan 2025 18:16:20 +0100 Subject: [PATCH 23/31] chore(deps): update dependency llama-index-core to v0.12.12 (#50) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5e7f50c..d8f19ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cloud-sql-python-connector[asyncpg]==1.16.0 -llama-index-core==0.12.11 +llama-index-core==0.12.12 pgvector==0.3.6 SQLAlchemy[asyncio]==2.0.37 From d1f1598cd687dcf4d3a48a236cc007dc9e090b4e Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:27:36 +0000 Subject: [PATCH 24/31] chore(docs): Update docstring (#54) docs: Update docstring --- src/llama_index_cloud_sql_pg/chat_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama_index_cloud_sql_pg/chat_store.py b/src/llama_index_cloud_sql_pg/chat_store.py index db277e9..bb3b4ba 100644 --- a/src/llama_index_cloud_sql_pg/chat_store.py +++ b/src/llama_index_cloud_sql_pg/chat_store.py @@ -36,7 +36,7 @@ def __init__( Args: key (object): Key to prevent direct constructor usage. engine (PostgresEngine): Database connection pool. - chat_store (AsyncPostgresChatStore): The async only IndexStore implementation + chat_store (AsyncPostgresChatStore): The async only ChatStore implementation Raises: Exception: If constructor is directly called by the user. 
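The `Raises` note in this docstring reflects a guarded-constructor pattern shared by every component class in this package (`PostgresChatStore`, `PostgresDocumentStore`, `PostgresIndexStore`, `PostgresVectorStore`): `__init__` takes a class-private `key` sentinel and rejects any caller that does not hold it, so instances can only come from the `create` factory methods where the engine setup runs. Below is a minimal standalone sketch of the pattern; the `GuardedStore` class and its fields are illustrative, not part of the library.

class GuardedStore:
    # Class-private sentinel. Name mangling keeps it unreachable from the
    # outside, so only code inside this class can satisfy the __init__ check.
    __create_key = object()

    def __init__(self, key: object, table_name: str) -> None:
        if key != GuardedStore.__create_key:
            raise Exception("Only create class through 'create' method!")
        self._table_name = table_name

    @classmethod
    def create(cls, table_name: str) -> "GuardedStore":
        # The factory is the only holder of the sentinel; in the real classes
        # this is also where table validation and connection setup happen.
        return cls(cls.__create_key, table_name)

store = GuardedStore.create("my_table")   # succeeds
try:
    GuardedStore(object(), "my_table")    # any other key raises
except Exception as err:
    print(err)  # Only create class through 'create' method!

This is what the `test_init_with_constructor` cases in the patches above assert: constructing with a fresh `key = object()` must raise.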
From 591600f13acac0ec7bf97ee3bc83041a99b3edec Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 29 Jan 2025 12:04:20 +0000 Subject: [PATCH 25/31] feat: Add Async Postgres Reader (#52) * feat: Add Async Postgres Reader * Fix test * Linter fix * update iterator to iterable * change default metadata_json_column name * Add extra tests for sync methods. --- src/llama_index_cloud_sql_pg/async_reader.py | 270 ++++++++++ tests/test_async_reader.py | 494 +++++++++++++++++++ 2 files changed, 764 insertions(+) create mode 100644 src/llama_index_cloud_sql_pg/async_reader.py create mode 100644 tests/test_async_reader.py diff --git a/src/llama_index_cloud_sql_pg/async_reader.py b/src/llama_index_cloud_sql_pg/async_reader.py new file mode 100644 index 0000000..ccd68a2 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_reader.py @@ -0,0 +1,270 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import Any, AsyncIterable, Callable, Iterable, Iterator, List, Optional + +from llama_index.core.bridge.pydantic import ConfigDict +from llama_index.core.readers.base import BasePydanticReader +from llama_index.core.schema import Document +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine + +DEFAULT_METADATA_COL = "li_metadata" + + +def text_formatter(row: dict, content_columns: list[str]) -> str: + """txt document formatter.""" + return " ".join(str(row[column]) for column in content_columns if column in row) + + +def csv_formatter(row: dict, content_columns: list[str]) -> str: + """CSV document formatter.""" + return ", ".join(str(row[column]) for column in content_columns if column in row) + + +def yaml_formatter(row: dict, content_columns: list[str]) -> str: + """YAML document formatter.""" + return "\n".join( + f"{column}: {str(row[column])}" for column in content_columns if column in row + ) + + +def json_formatter(row: dict, content_columns: list[str]) -> str: + """JSON document formatter.""" + dictionary = {} + for column in content_columns: + if column in row: + dictionary[column] = row[column] + return json.dumps(dictionary) + + +def _parse_doc_from_row( + content_columns: Iterable[str], + metadata_columns: Iterable[str], + row: dict, + formatter: Callable = text_formatter, + metadata_json_column: Optional[str] = DEFAULT_METADATA_COL, +) -> Document: + """Parse row into document.""" + text = formatter(row, content_columns) + metadata: dict[str, Any] = {} + # unnest metadata from li_metadata column + if metadata_json_column and row.get(metadata_json_column): + for k, v in row[metadata_json_column].items(): + metadata[k] = v + # load metadata from other columns + for column in metadata_columns: + if column in row and column != metadata_json_column: + metadata[column] = row[column] + + return Document(text=text, extra_info=metadata) + + +class AsyncPostgresReader(BasePydanticReader): + """Load 
documents from Cloud SQL for PostgreSQL.
+
+    Each document represents one row of the result. The `content_columns` are
+    written into the `text` of the document. The `metadata_columns` are written
+    into the `metadata` of the document. By default, the first column is written
+    into the `text` and everything else into the `metadata`.
+    """
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        pool: AsyncEngine,
+        query: str,
+        content_columns: list[str],
+        metadata_columns: list[str],
+        formatter: Callable,
+        metadata_json_column: Optional[str] = None,
+        is_remote: bool = True,
+    ) -> None:
+        """AsyncPostgresReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            pool (AsyncEngine): AsyncEngine with pool connection to the Cloud SQL Postgres database
+            query (str): SQL query to execute.
+            content_columns (list[str]): Columns that represent a Document's text.
+            metadata_columns (list[str]): Columns that represent a Document's metadata.
+            formatter (Callable): A function to format page content (OneOf: format, formatter).
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by user.
+        """
+        if key != AsyncPostgresReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._pool = pool
+        self._query = query
+        self._content_columns = content_columns
+        self._metadata_columns = metadata_columns
+        self._formatter = formatter
+        self._metadata_json_column = metadata_json_column
+
+    @classmethod
+    async def create(
+        cls: type[AsyncPostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> AsyncPostgresReader:
+        """Create an AsyncPostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): AsyncEngine with pool connection to the Cloud SQL Postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Columns that represent a Document's text. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (bool): Whether the data is loaded from a remote API or a local file.
+
+
+        Returns:
+            AsyncPostgresReader: A newly created instance of AsyncPostgresReader.
+ """ + if table_name and query: + raise ValueError("Only one of 'table_name' or 'query' should be specified.") + if not table_name and not query: + raise ValueError( + "At least one of the parameters 'table_name' or 'query' needs to be provided" + ) + if format and formatter: + raise ValueError("Only one of 'format' or 'formatter' should be specified.") + + if format and format not in ["csv", "text", "JSON", "YAML"]: + raise ValueError("format must be type: 'csv', 'text', 'JSON', 'YAML'") + if formatter: + formatter = formatter + elif format == "csv": + formatter = csv_formatter + elif format == "YAML": + formatter = yaml_formatter + elif format == "JSON": + formatter = json_formatter + else: + formatter = text_formatter + + if not query: + query = f'SELECT * FROM "{schema_name}"."{table_name}"' + + async with engine._pool.connect() as connection: + result_proxy = await connection.execute(text(query)) + column_names = list(result_proxy.keys()) + # Select content or default to first column + content_columns = content_columns or [column_names[0]] + # Select metadata columns + metadata_columns = metadata_columns or [ + col for col in column_names if col not in content_columns + ] + + # Check validity of metadata json column + if metadata_json_column and metadata_json_column not in column_names: + raise ValueError( + f"Column {metadata_json_column} not found in query result {column_names}." + ) + + if metadata_json_column and metadata_json_column in column_names: + metadata_json_column = metadata_json_column + elif DEFAULT_METADATA_COL in column_names: + metadata_json_column = DEFAULT_METADATA_COL + else: + metadata_json_column = None + + # check validity of other column + all_names = content_columns + metadata_columns + for name in all_names: + if name not in column_names: + raise ValueError( + f"Column {name} not found in query result {column_names}." + ) + return cls( + key=cls.__create_key, + pool=engine._pool, + query=query, + content_columns=content_columns, + metadata_columns=metadata_columns, + formatter=formatter, + metadata_json_column=metadata_json_column, + is_remote=is_remote, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncPostgresReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load Cloud SQL Postgres data into Document objects.""" + return [doc async for doc in self.alazy_load_data()] + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load Cloud SQL Postgres data into Document objects lazily.""" + async with self._pool.connect() as connection: + result_proxy = await connection.execute(text(self._query)) + # load document one by one + while True: + row = result_proxy.fetchone() + if not row: + break + + row_data = {} + column_names = self._content_columns + self._metadata_columns + column_names += ( + [self._metadata_json_column] if self._metadata_json_column else [] + ) + for column in column_names: + value = getattr(row, column) + row_data[column] = value + + yield _parse_doc_from_row( + self._content_columns, + self._metadata_columns, + row_data, + self._formatter, + self._metadata_json_column, + ) + + def lazy_load_data(self) -> Iterable[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + ) + + def load_data(self) -> List[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." 
+ ) diff --git a/tests/test_async_reader.py b/tests/test_async_reader.py new file mode 100644 index 0000000..6e2a665 --- /dev/null +++ b/tests/test_async_reader.py @@ -0,0 +1,494 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.async_reader import AsyncPostgresReader + +default_table_name_async = "reader_test_" + str(uuid.uuid4()) +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresReader. Use PostgresReader interface instead." + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncPostgresReader: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + await self._create_default_table(async_engine) + + yield async_engine + + await self._cleanup_table(async_engine) + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _create_default_table(self, engine): + create_query = f""" + CREATE TABLE IF NOT EXISTS "{default_table_name_async}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT 
NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(engine, create_query) + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await AsyncPostgresReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_lazy_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.lazy_load_data() + + async def test_load_data(self, async_engine): + with pytest.raises(Exception, match=sync_method_exception_str): + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + ) + + reader.load_data() + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + 
engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, async_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(async_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + 
"organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_formatter( + self, async_engine + ): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + def my_formatter(row, content_columns): + return "-".join( + str(row[column]) for column in content_columns if column in row + ) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + formatter=my_formatter, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_documents = [ + Document( + text="Granny Smith-150-1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_documents, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata_custom_page_content_format( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await AsyncPostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="YAML", + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_docs = [ + Document( + text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') From 4ceade46a00980d2e75d03fde11b8a1f888dfc25 Mon Sep 17 00:00:00 2001 From: dishaprakash <57954147+dishaprakash@users.noreply.github.com> Date: Wed, 29 Jan 2025 12:15:34 +0000 Subject: [PATCH 26/31] feat: Add Postgres Reader (#53) * feat: Add Postgres Reader * Linter fix * Change metadata_json_column default value * Add method comment about type mismatch --- src/llama_index_cloud_sql_pg/__init__.py | 2 + src/llama_index_cloud_sql_pg/reader.py | 187 +++++ 
tests/test_reader.py | 900 +++++++++++++++++++++++
 3 files changed, 1089 insertions(+)
 create mode 100644 src/llama_index_cloud_sql_pg/reader.py
 create mode 100644 tests/test_reader.py

diff --git a/src/llama_index_cloud_sql_pg/__init__.py b/src/llama_index_cloud_sql_pg/__init__.py
index 4a367b5..e669eac 100644
--- a/src/llama_index_cloud_sql_pg/__init__.py
+++ b/src/llama_index_cloud_sql_pg/__init__.py
@@ -16,6 +16,7 @@
 from .document_store import PostgresDocumentStore
 from .engine import Column, PostgresEngine
 from .index_store import PostgresIndexStore
+from .reader import PostgresReader
 from .vector_store import PostgresVectorStore
 from .version import __version__

@@ -25,6 +26,7 @@
     "PostgresEngine",
     "PostgresDocumentStore",
     "PostgresIndexStore",
+    "PostgresReader",
     "PostgresVectorStore",
     "__version__",
 ]
diff --git a/src/llama_index_cloud_sql_pg/reader.py b/src/llama_index_cloud_sql_pg/reader.py
new file mode 100644
index 0000000..374094a
--- /dev/null
+++ b/src/llama_index_cloud_sql_pg/reader.py
@@ -0,0 +1,187 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import AsyncIterable, Callable, Iterable, List, Optional
+
+from llama_index.core.bridge.pydantic import ConfigDict
+from llama_index.core.readers.base import BasePydanticReader
+from llama_index.core.schema import Document
+
+from .async_reader import AsyncPostgresReader
+from .engine import PostgresEngine
+
+DEFAULT_METADATA_COL = "li_metadata"
+
+
+class PostgresReader(BasePydanticReader):
+    """Load documents from a Cloud SQL for PostgreSQL database."""
+
+    __create_key = object()
+    is_remote: bool = True
+
+    def __init__(
+        self,
+        key: object,
+        engine: PostgresEngine,
+        reader: AsyncPostgresReader,
+        is_remote: bool = True,
+    ) -> None:
+        """PostgresReader constructor.
+
+        Args:
+            key (object): Prevent direct constructor usage.
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            reader (AsyncPostgresReader): The async only PostgresReader implementation
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+        Raises:
+            Exception: If called directly by user.
+        """
+        if key != PostgresReader.__create_key:
+            raise Exception("Only create class through 'create' method!")
+
+        super().__init__(is_remote=is_remote)
+
+        self._engine = engine
+        self.__reader = reader
+
+    @classmethod
+    async def create(
+        cls: type[PostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> PostgresReader:
+        """Asynchronously create a PostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Columns that represent a Document's text. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+
+        Returns:
+            PostgresReader: A newly created instance of PostgresReader.
+        """
+        coro = AsyncPostgresReader.create(
+            engine=engine,
+            query=query,
+            table_name=table_name,
+            schema_name=schema_name,
+            content_columns=content_columns,
+            metadata_columns=metadata_columns,
+            metadata_json_column=metadata_json_column,
+            format=format,
+            formatter=formatter,
+            is_remote=is_remote,
+        )
+        reader = await engine._run_as_async(coro)
+        return cls(cls.__create_key, engine, reader, is_remote)
+
+    @classmethod
+    def create_sync(
+        cls: type[PostgresReader],
+        engine: PostgresEngine,
+        query: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema_name: str = "public",
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
+        metadata_json_column: Optional[str] = None,
+        format: Optional[str] = None,
+        formatter: Optional[Callable] = None,
+        is_remote: bool = True,
+    ) -> PostgresReader:
+        """Synchronously create a PostgresReader instance.
+
+        Args:
+            engine (PostgresEngine): PostgresEngine with pool connection to the Cloud SQL Postgres database
+            query (Optional[str], optional): SQL query. Defaults to None.
+            table_name (Optional[str], optional): Name of table to query. Defaults to None.
+            schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
+            content_columns (Optional[list[str]], optional): Columns that represent a Document's text. Defaults to the first column.
+            metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+            metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "li_metadata".
+            format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
+            formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
+            is_remote (Optional[bool]): Whether the data is loaded from a remote API or a local file.
+
+
+        Returns:
+            PostgresReader: A newly created instance of PostgresReader.
+ """ + coro = AsyncPostgresReader.create( + engine=engine, + query=query, + table_name=table_name, + schema_name=schema_name, + content_columns=content_columns, + metadata_columns=metadata_columns, + metadata_json_column=metadata_json_column, + format=format, + formatter=formatter, + is_remote=is_remote, + ) + reader = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, reader, is_remote) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "PostgresReader" + + async def aload_data(self) -> list[Document]: + """Asynchronously load Cloud SQL postgres data into Document objects.""" + return await self._engine._run_as_async(self.__reader.aload_data()) + + def load_data(self) -> list[Document]: + """Synchronously load Cloud SQL postgres data into Document objects.""" + return self._engine._run_as_sync(self.__reader.aload_data()) + + async def alazy_load_data(self) -> AsyncIterable[Document]: # type: ignore + """Asynchronously load Cloud SQL postgres data into Document objects lazily.""" + # The return type in the underlying base class is an Iterable which we are overriding to an AsyncIterable in this implementation. + iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = await self._engine._run_as_async(iterator.__anext__()) + yield result + except StopAsyncIteration: + break + + def lazy_load_data(self) -> Iterable[Document]: # type: ignore + """Synchronously load Cloud SQL postgres data into Document objects lazily.""" + iterator = self.__reader.alazy_load_data().__aiter__() + while True: + try: + result = self._engine._run_as_sync(iterator.__anext__()) + yield result + except StopAsyncIteration: + break diff --git a/tests/test_reader.py b/tests/test_reader.py new file mode 100644 index 0000000..fe5b50b --- /dev/null +++ b/tests/test_reader.py @@ -0,0 +1,900 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import uuid +from typing import Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import Document +from sqlalchemy import RowMapping, text + +from llama_index_cloud_sql_pg import PostgresEngine, PostgresReader + +default_table_name_async = "async_reader_test_" + str(uuid.uuid4()) +default_table_name_sync = "sync_reader_test_" + str(uuid.uuid4()) + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestPostgresReaderAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, + db_project, + db_region, + db_instance, + db_name, + ): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield async_engine + + await aexecute( + async_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await async_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + async def _collect_async_items(self, docs_generator): + """Collects items from an async generator.""" + docs = [] + async for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, async_engine): + with pytest.raises(ValueError): + await PostgresReader.create( + engine=async_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + await PostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + await PostgresReader.create( + engine=async_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, async_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name 
VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + table_name=table_name, + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, async_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(async_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(async_engine, insert_query) + + reader = await PostgresReader.create( + engine=async_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = await self._collect_async_items(reader.alazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. 
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata(
+        self, async_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(async_engine, insert_query)
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_text_docs = [
+            Document(
+                text="Granny Smith 150 1",
+                metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
+            )
+        ]
+
+        for expected, actual in zip(expected_text_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="JSON",
+        )
+
+        actual_documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_docs = [
+            Document(
+                text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}',
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, actual_documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_with_json(self, async_engine):
+
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}"(
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety JSON NOT NULL,
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                li_metadata JSON NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        metadata = json.dumps({"organic": 1})
+        variety = json.dumps({"type": "Granny Smith"})
+        insert_query = f"""
+            INSERT INTO "{table_name}"
+            (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata)
+            VALUES ('Apple', '{variety}', 150, 1, '{metadata}');"""
+        await aexecute(async_engine, insert_query)
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            metadata_columns=[
+                "variety",
+            ],
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_docs = [
+            Document(
+                text="1",
+                metadata={
+                    "variety": {"type": "Granny Smith"},
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_formatter(
+        self, async_engine
+    ):
+
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(async_engine, insert_query)
+
+        def my_formatter(row, content_columns):
+            return "-".join(
+                str(row[column]) for column in content_columns if column in row
+            )
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            formatter=my_formatter,
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_documents = [
+            Document(
+                text="Granny Smith-150-1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_documents, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_page_content_format(
+        self, async_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(async_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(async_engine, insert_query)
+
+        reader = await PostgresReader.create(
+            engine=async_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="YAML",
+        )
+
+        documents = await self._collect_async_items(reader.alazy_load_data())
+
+        expected_docs = [
+            Document(
+                text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+
+@pytest.mark.asyncio(loop_scope="class")
+class TestPostgresReaderSync:
+    @pytest.fixture(scope="module")
+    def db_project(self) -> str:
+        return get_env_var("PROJECT_ID", "project id for google cloud")
+
+    @pytest.fixture(scope="module")
+    def db_region(self) -> str:
+        return get_env_var("REGION", "region for Cloud SQL instance")
+
+    @pytest.fixture(scope="module")
+    def db_instance(self) -> str:
+        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")
+
+    @pytest.fixture(scope="module")
+    def db_name(self) -> str:
+        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")
+
+    @pytest.fixture(scope="module")
+    def user(self) -> str:
+        return get_env_var("DB_USER", "database user for Cloud SQL")
+
+    @pytest.fixture(scope="module")
+    def password(self) -> str:
+        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")
"database password for Cloud SQL") + + @pytest_asyncio.fixture(scope="class") + async def sync_engine( + self, + db_project, + db_region, + db_instance, + db_name, + ): + sync_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield sync_engine + + await aexecute( + sync_engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"' + ) + + await sync_engine.close() + + async def _cleanup_table(self, engine): + await aexecute(engine, f'DROP TABLE IF EXISTS "{default_table_name_async}"') + + def _collect_items(self, docs_generator): + """Collects items from a generator.""" + docs = [] + for doc in docs_generator: + docs.append(doc) + return docs + + async def test_create_reader_with_invalid_parameters(self, sync_engine): + with pytest.raises(ValueError): + PostgresReader.create_sync( + engine=sync_engine, + ) + with pytest.raises(ValueError): + + def fake_formatter(): + return None + + PostgresReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="text", + formatter=fake_formatter, + ) + with pytest.raises(ValueError): + PostgresReader.create_sync( + engine=sync_engine, + table_name=default_table_name_async, + format="fake_format", + ) + + async def test_load_from_query_default(self, sync_engine): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" ( + fruit_name, variety, quantity_in_stock, price_per_unit, organic + ) VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + table_name=table_name, + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_document = Document( + text="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + + assert documents[0].text == expected_document.text + assert documents[0].metadata == expected_document.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_customized_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + expected_docs = [ + Document( + text="Apple Smith 150 1 1", + metadata={"fruit_id": 1}, + ), + Document( + text="Banana Cavendish 200 1 0", + metadata={"fruit_id": 2}, + ), + Document( + text="Orange Navel 80 1 1", + metadata={"fruit_id": 3}, + ), + ] + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "fruit_name", + 
"variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + # Compare the full list of documents to make sure all are in sync. + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_customized_content_default_metadata( + self, sync_engine + ): + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}" ( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + organic INT NOT NULL + ) + """ + await aexecute(sync_engine, query) + + insert_query = f""" + INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES ('Apple', 'Granny Smith', 150, 1, 1); + """ + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_text_docs = [ + Document( + text="Granny Smith 150 1", + metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1}, + ) + ] + + for expected, actual in zip(expected_text_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + format="JSON", + ) + + actual_documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text='{"variety": "Granny Smith", "quantity_in_stock": 150, "price_per_unit": 1}', + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, actual_documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"') + + async def test_load_from_query_with_json(self, sync_engine): + + table_name = "test-table" + str(uuid.uuid4()) + query = f""" + CREATE TABLE IF NOT EXISTS "{table_name}"( + fruit_id SERIAL PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety JSON NOT NULL, + quantity_in_stock INT NOT NULL, + price_per_unit INT NOT NULL, + li_metadata JSON NOT NULL + ) + """ + await aexecute(sync_engine, query) + + metadata = json.dumps({"organic": 1}) + variety = json.dumps({"type": "Granny Smith"}) + insert_query = f""" + INSERT INTO "{table_name}" + (fruit_name, variety, quantity_in_stock, price_per_unit, li_metadata) + VALUES ('Apple', '{variety}', 150, 1, '{metadata}');""" + await aexecute(sync_engine, insert_query) + + reader = PostgresReader.create_sync( + engine=sync_engine, + query=f'SELECT * FROM "{table_name}";', + metadata_columns=[ + "variety", + ], + ) + + documents = self._collect_items(reader.lazy_load_data()) + + expected_docs = [ + Document( + text="1", + metadata={ + "variety": {"type": "Granny Smith"}, + "organic": 1, + }, + ) + ] + + for expected, actual in zip(expected_docs, documents): + assert expected.text == actual.text + assert expected.metadata == actual.metadata + + await 
+
+    async def test_load_from_query_customized_content_default_metadata_custom_formatter(
+        self, sync_engine
+    ):
+
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        def my_formatter(row, content_columns):
+            return "-".join(
+                str(row[column]) for column in content_columns if column in row
+            )
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            formatter=my_formatter,
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_documents = [
+            Document(
+                text="Granny Smith-150-1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_documents, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
+
+    async def test_load_from_query_customized_content_default_metadata_custom_page_content_format(
+        self, sync_engine
+    ):
+        table_name = "test-table" + str(uuid.uuid4())
+        query = f"""
+            CREATE TABLE IF NOT EXISTS "{table_name}" (
+                fruit_id SERIAL PRIMARY KEY,
+                fruit_name VARCHAR(100) NOT NULL,
+                variety VARCHAR(50),
+                quantity_in_stock INT NOT NULL,
+                price_per_unit INT NOT NULL,
+                organic INT NOT NULL
+            )
+        """
+        await aexecute(sync_engine, query)
+
+        insert_query = f"""
+            INSERT INTO "{table_name}" (fruit_name, variety, quantity_in_stock, price_per_unit, organic)
+            VALUES ('Apple', 'Granny Smith', 150, 1, 1);
+        """
+        await aexecute(sync_engine, insert_query)
+
+        reader = PostgresReader.create_sync(
+            engine=sync_engine,
+            query=f'SELECT * FROM "{table_name}";',
+            content_columns=[
+                "variety",
+                "quantity_in_stock",
+                "price_per_unit",
+            ],
+            format="YAML",
+        )
+
+        documents = self._collect_items(reader.lazy_load_data())
+
+        expected_docs = [
+            Document(
+                text="variety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 1",
+                metadata={
+                    "fruit_id": 1,
+                    "fruit_name": "Apple",
+                    "organic": 1,
+                },
+            )
+        ]
+
+        for expected, actual in zip(expected_docs, documents):
+            assert expected.text == actual.text
+            assert expected.metadata == actual.metadata
+
+        await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{table_name}"')
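
A minimal usage sketch of the `PostgresReader` surface exercised by the tests above, for readers following the series: the `PostgresEngine.afrom_instance`, `PostgresReader.create`, and `alazy_load_data` calls are the ones the tests use, while the `fruits` table, its columns, and the environment-variable names are illustrative assumptions rather than part of the patch.

import asyncio
import os

from llama_index_cloud_sql_pg import PostgresEngine, PostgresReader


async def main() -> None:
    # Connect to a Cloud SQL for PostgreSQL instance; the env vars mirror
    # the fixtures in the tests above and are assumed to be set.
    engine = await PostgresEngine.afrom_instance(
        project_id=os.environ["PROJECT_ID"],
        instance=os.environ["INSTANCE_ID"],
        region=os.environ["REGION"],
        database=os.environ["DATABASE_ID"],
    )

    # Hypothetical table: the listed columns become the document text,
    # and "fruit_id" is carried along as metadata.
    reader = await PostgresReader.create(
        engine=engine,
        query='SELECT * FROM "fruits";',
        content_columns=["fruit_name", "variety", "quantity_in_stock"],
        metadata_columns=["fruit_id"],
    )

    # alazy_load_data() yields llama-index Documents one at a time.
    async for doc in reader.alazy_load_data():
        print(doc.text, doc.metadata)

    await engine.close()


asyncio.run(main())
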
From a027df90fcfd4cc301bb584855aa97f7e5ef9e66 Mon Sep 17 00:00:00 2001
From: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Date: Wed, 29 Jan 2025 17:29:19 +0000
Subject: [PATCH 27/31] chore(docs): Fix minor typos in sample notebooks (#60)

* chore(docs): Fix minor typos in sample notebooks

* chore(docs): Fix minor typos in sample notebooks

---
 samples/llama_index_doc_store.ipynb    | 2 +-
 samples/llama_index_vector_store.ipynb | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/samples/llama_index_doc_store.ipynb b/samples/llama_index_doc_store.ipynb
index 8c9e78b..c15ef8e 100644
--- a/samples/llama_index_doc_store.ipynb
+++ b/samples/llama_index_doc_store.ipynb
@@ -8,7 +8,7 @@
   "source": [
    "# Google Cloud SQL for PostgreSQL - `PostgresDocumentStore` & `PostgresIndexStore`\n",
    "\n",
-   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
+   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
    "\n",
    "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store documents and indexes with the `PostgresDocumentStore` and `PostgresIndexStore` classes.\n",
    "\n",
diff --git a/samples/llama_index_vector_store.ipynb b/samples/llama_index_vector_store.ipynb
index fd8cc3e..e482cd1 100644
--- a/samples/llama_index_vector_store.ipynb
+++ b/samples/llama_index_vector_store.ipynb
@@ -8,7 +8,7 @@
   "source": [
    "# Google Cloud SQL for PostgreSQL - `PostgresVectorStore`\n",
    "\n",
-   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers PostgreSQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
+   "> [Cloud SQL](https://cloud.google.com/sql) is a fully managed relational database service that offers high performance, seamless integration, and impressive scalability. It offers MySQL, PostgreSQL, and SQL Server database engines. Extend your database application to build AI-powered experiences leveraging Cloud SQL's LlamaIndex integrations.\n",
    "\n",
    "This notebook goes over how to use `Cloud SQL for PostgreSQL` to store vector embeddings with the `PostgresVectorStore` class.\n",
    "\n",

From 4e0f16c7c241ce83e69745bdb00697458f2be6e8 Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:22:47 +0100
Subject: [PATCH 28/31] chore(deps): update actions/setup-python action to v5.4.0 (#56)

---
 .github/workflows/lint.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 28d3312..fbd4535 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -34,7 +34,7 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
 
       - name: Setup Python
-        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
        with:
           python-version: "3.11"
 
From cd8d15bfa62d74a0ed272d7e9502f2915303fa1e Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:33:43 +0100
Subject: [PATCH 29/31] chore(deps): update dependency isort to v6 (#55)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index b816c10..e5b564d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/
 
 [project.optional-dependencies]
 test = [
     "black[jupyter]==24.10.0",
-    "isort==5.13.2",
+    "isort==6.0.0",
     "mypy==1.14.1",
     "pytest-asyncio==0.25.2",
     "pytest==8.3.4",

From 1ec027390cc1afdd6d8d61ff8a8279165d71d875 Mon Sep 17 00:00:00 2001
From: Mend Renovate
Date: Thu, 30 Jan 2025 20:43:20 +0100
Subject: [PATCH 30/31] chore(deps): update dependency black to v25 (#57)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index e5b564d..d59c095 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ Changelog = "https://github.com/googleapis/llama-index-cloud-sql-pg-python/blob/
 
 [project.optional-dependencies]
 test = [
-    "black[jupyter]==24.10.0",
+    "black[jupyter]==25.1.0",
     "isort==6.0.0",
     "mypy==1.14.1",
     "pytest-asyncio==0.25.2",

From 458c07294e65ba66a2a40efb2e723576a02a0a79 Mon Sep 17 00:00:00 2001
From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com>
Date: Thu, 30 Jan 2025 13:41:10 -0800
Subject: [PATCH 31/31] chore(main): release 0.2.0 (#28)

Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com>
---
 CHANGELOG.md                            | 16 ++++++++++++++++
 src/llama_index_cloud_sql_pg/version.py |  2 +-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 29b89fe..d198894 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,21 @@
 # Changelog
 
+## [0.2.0](https://github.com/googleapis/llama-index-cloud-sql-pg-python/compare/v0.1.0...v0.2.0) (2025-01-30)
+
+
+### Features
+
+* Add Async Chat Store ([#38](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/38)) ([2b14f5a](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/2b14f5a946e595bce145bf1b526138cf393250ed))
+* Add Async Postgres Reader ([#52](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/52)) ([591600f](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/591600f13acac0ec7bf97ee3bc83041a99b3edec))
+* Add chat store init methods ([#39](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/39)) ([0ef1fa5](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/0ef1fa5c945c9012354fc6cacb4fc50dd12c0c19))
+* Add Postgres Chat Store ([#40](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/40)) ([7787d7d](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/7787d7d1161dd994c11ac8a75eb5890cf9309cee))
+* Add Postgres Reader ([#53](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/53)) ([4ceade4](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/4ceade46a00980d2e75d03fde11b8a1f888dfc25))
+
+
+### Bug Fixes
+
+* Query and return only selected metadata columns ([#48](https://github.com/googleapis/llama-index-cloud-sql-pg-python/issues/48)) ([5173e11](https://github.com/googleapis/llama-index-cloud-sql-pg-python/commit/5173e11831387909a12841bb232f8e39c113bd60))
+
 ## 0.1.0 (2024-12-03)
 
diff --git a/src/llama_index_cloud_sql_pg/version.py b/src/llama_index_cloud_sql_pg/version.py
index c1c8212..20c5861 100644
--- a/src/llama_index_cloud_sql_pg/version.py
+++ b/src/llama_index_cloud_sql_pg/version.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
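
As a companion to the 0.2.0 release notes above, a hedged sketch of the synchronous reader path (`PostgresReader.create_sync` / `lazy_load_data`) with a custom row formatter, matching what the new reader tests cover. The engine is still built with the async `afrom_instance` constructor used throughout the test suite, and the `fruits` table and dash formatter are illustrative assumptions.

import asyncio
import os

from llama_index_cloud_sql_pg import PostgresEngine, PostgresReader


def dash_formatter(row, content_columns):
    # Join the selected columns with dashes, e.g. "Granny Smith-150-1".
    return "-".join(str(row[col]) for col in content_columns if col in row)


async def main() -> None:
    engine = await PostgresEngine.afrom_instance(
        project_id=os.environ["PROJECT_ID"],
        instance=os.environ["INSTANCE_ID"],
        region=os.environ["REGION"],
        database=os.environ["DATABASE_ID"],
    )

    # The sync factory and loader can be driven from regular code as well;
    # columns not selected as content land in each document's metadata.
    reader = PostgresReader.create_sync(
        engine=engine,
        query='SELECT * FROM "fruits";',
        content_columns=["variety", "quantity_in_stock", "price_per_unit"],
        formatter=dash_formatter,
    )

    for doc in reader.lazy_load_data():
        print(doc.text, doc.metadata)

    await engine.close()


asyncio.run(main())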