diff --git a/CHANGELOG.md b/CHANGELOG.md
index b05ddb37..270dc017 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,18 @@
 # Changelog
 
+## [0.14.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.13.0...v0.14.0) (2025-04-29)
+
+
+### Features
+
+* Update Postgres VectorStore to expected LangChain functionality ([#290](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/290)) ([605c31d](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/605c31dad924d359e2de0bc26476c5830a0cba69))
+
+
+### Bug Fixes
+
+* **docs:** Fix link in README ([#293](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/293)) ([6bfc58c](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/6bfc58cd7d546cb454ff2a18c74cec287fc764cb))
+* Update JSON conversion ([#296](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/issues/296)) ([4313ba2](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/commit/4313ba234f2328537ed9401c1152d78c1ab71440))
+
 ## [0.13.0](https://github.com/googleapis/langchain-google-cloud-sql-pg-python/compare/v0.12.1...v0.13.0) (2025-03-17)
diff --git a/README.rst b/README.rst
index c5da0942..6833433b 100644
--- a/README.rst
+++ b/README.rst
@@ -172,10 +172,83 @@ Use ``PostgresSaver`` to save snapshots of the graph state at a given point in t
 
 See the full `Checkpoint`_ tutorial.
 
-.. _`Checkpoint`: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/docs/langgraph_checkpointer.ipynb
+.. _`Checkpoint`: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/blob/main/docs/langgraph_checkpoint.ipynb
+
+Example Usage
+-------------
+
+Code examples can be found in the `samples/`_ folder.
+
+.. _samples/: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/samples
+
+Converting between Sync & Async Usage
+-------------------------------------
+
+Async functionality improves the speed and efficiency of database connections through concurrency,
+which is key to providing enterprise-quality performance and scaling in GenAI applications. This
+package uses a native async Postgres driver, `asyncpg`_, to make the most of Python's async functionality.
+
+LangChain supports `async programming`_, since LLM-based applications utilize many I/O-bound operations,
+such as making API calls to language models, databases, or other services. All components should provide
+both sync and async versions of their methods.
+
+`asyncio`_ is a Python library used for concurrent programming and serves as the foundation for multiple
+Python asynchronous frameworks. asyncio uses ``async`` / ``await`` syntax to achieve concurrency for
+non-blocking I/O-bound tasks using one thread with cooperative multitasking instead of multi-threading.
+
+.. _`async programming`: https://python.langchain.com/docs/concepts/async/
+.. _`asyncio`: https://docs.python.org/3/library/asyncio.html
+.. _`asyncpg`: https://github.com/MagicStack/asyncpg
+
+Converting Sync to Async
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+Update sync method calls to ``await`` the corresponding async methods.
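+
+For example, this synchronous setup (a minimal sketch using this package's
+sync APIs; the project, instance, and table names are placeholders):
+
+.. code:: python
+
+    engine = PostgresEngine.from_instance("project-id", "region", "my-instance", "my-database")
+    engine.init_vectorstore_table(table_name="my-table", vector_size=768)
+    vectorstore = PostgresVectorStore.create_sync(
+        engine,
+        table_name="my-table",
+        embedding_service=VertexAIEmbeddings(model_name="textembedding-gecko@003")
+    )
+
+becomes the async version: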
+
+.. code:: python
+
+    engine = await PostgresEngine.afrom_instance("project-id", "region", "my-instance", "my-database")
+    await engine.ainit_vectorstore_table(table_name="my-table", vector_size=768)
+    vectorstore = await PostgresVectorStore.create(
+        engine,
+        table_name="my-table",
+        embedding_service=VertexAIEmbeddings(model_name="textembedding-gecko@003")
+    )
+
+Run the code: Notebooks
+^^^^^^^^^^^^^^^^^^^^^^^
+
+IPython and Jupyter notebooks support the use of the ``await`` keyword without any additional setup.
+
+Run the code: FastAPI
+^^^^^^^^^^^^^^^^^^^^^
+
+Update routes to use ``async def``.
+
+.. code:: python
+
+    @app.get("/invoke/")
+    async def invoke(query: str):
+        return await retriever.ainvoke(query)
+
+
+Run the code: Local Python file
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It is recommended to create a top-level ``async def`` function that wraps multiple async calls,
+then use ``asyncio.run()`` to run the top-level entrypoint, e.g. ``main()``:
+
+.. code:: python
+
+    import asyncio
+
+    async def main():
+        response = await retriever.ainvoke(query)
+        print(response)
+
+    asyncio.run(main())
+
 Contributions
-~~~~~~~~~~~~~
+-------------
 
 Contributions to this library are always welcome and highly encouraged.
 
@@ -188,7 +261,14 @@ information.
 .. _`CONTRIBUTING`: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/CONTRIBUTING.md
 .. _`Code of Conduct`: https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/CODE_OF_CONDUCT.md
 
+License
+-------
+
+Apache 2.0 - See
+`LICENSE <https://github.com/googleapis/langchain-google-cloud-sql-pg-python/tree/main/LICENSE>`_
+for more information.
+
 Disclaimer
-~~~~~~~~~~~
+----------
 
 This is not an officially supported Google product.
diff --git a/pyproject.toml b/pyproject.toml
index 63bb192d..5e5a3ec3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,10 +46,11 @@ test = [
     "black[jupyter]==25.1.0",
     "isort==6.0.1",
     "mypy==1.15.0",
-    "pytest-asyncio==0.25.3",
-    "pytest==8.3.4",
-    "pytest-cov==6.0.0",
-    "langgraph==0.2.74"
+    "pytest-asyncio==0.26.0",
+    "pytest==8.3.5",
+    "pytest-cov==6.1.1",
+    "langchain-tests==0.3.19",
+    "langgraph==0.3.31"
 ]
 
 [build-system]
diff --git a/requirements.txt b/requirements.txt
index c52a1b3d..31d4c8af 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-cloud-sql-python-connector[asyncpg]==1.17.0
-langchain-core==0.3.40
-numpy==2.2.3; python_version > "3.9"
-numpy== 2.0.2; python_version <= "3.9"
-pgvector==0.3.6
-SQLAlchemy[asyncio]==2.0.38
-langgraph-checkpoint==2.0.10
\ No newline at end of file
+cloud-sql-python-connector[asyncpg]==1.18.1
+langchain-core==0.3.55
+numpy==2.2.5; python_version > "3.9"
+numpy==2.0.2; python_version <= "3.9"
+pgvector==0.4.0
+SQLAlchemy[asyncio]==2.0.40
+langgraph==0.3.31
diff --git a/samples/index_tuning_sample/requirements.txt b/samples/index_tuning_sample/requirements.txt
index 19489979..e2124edf 100644
--- a/samples/index_tuning_sample/requirements.txt
+++ b/samples/index_tuning_sample/requirements.txt
@@ -1,3 +1,3 @@
-langchain-community==0.3.18
-langchain-google-cloud-sql-pg==0.12.1
-langchain-google-vertexai==2.0.14
+langchain-community==0.3.22
+langchain-google-cloud-sql-pg==0.13.0
+langchain-google-vertexai==2.0.20
diff --git a/samples/langchain_on_vertexai/requirements.txt b/samples/langchain_on_vertexai/requirements.txt
index 153755af..f823fca7 100644
--- a/samples/langchain_on_vertexai/requirements.txt
+++ b/samples/langchain_on_vertexai/requirements.txt
@@ -1,5 +1,5 @@
-google-cloud-aiplatform[reasoningengine,langchain]==1.81.0
-google-cloud-resource-manager==1.14.1
-langchain-community==0.3.18
-langchain-google-cloud-sql-pg==0.12.1 -langchain-google-vertexai==2.0.14 +google-cloud-aiplatform[reasoningengine,langchain]==1.89.0 +google-cloud-resource-manager==1.14.2 +langchain-community==0.3.22 +langchain-google-cloud-sql-pg==0.13.0 +langchain-google-vertexai==2.0.20 diff --git a/samples/requirements.txt b/samples/requirements.txt index 153755af..f823fca7 100644 --- a/samples/requirements.txt +++ b/samples/requirements.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform[reasoningengine,langchain]==1.81.0 -google-cloud-resource-manager==1.14.1 -langchain-community==0.3.18 -langchain-google-cloud-sql-pg==0.12.1 -langchain-google-vertexai==2.0.14 +google-cloud-aiplatform[reasoningengine,langchain]==1.89.0 +google-cloud-resource-manager==1.14.2 +langchain-community==0.3.22 +langchain-google-cloud-sql-pg==0.13.0 +langchain-google-vertexai==2.0.20 diff --git a/src/langchain_google_cloud_sql_pg/async_checkpoint.py b/src/langchain_google_cloud_sql_pg/async_checkpoint.py index 560182f7..fc875991 100644 --- a/src/langchain_google_cloud_sql_pg/async_checkpoint.py +++ b/src/langchain_google_cloud_sql_pg/async_checkpoint.py @@ -220,7 +220,7 @@ def _search_where( # construct predicate for metadata filter if filter: - wheres.append("encode(metadata,'escape')::jsonb @> :metadata ") + wheres.append("convert_from(metadata,'UTF8')::jsonb @> :metadata ") param_values.update({"metadata": f"{json.dumps(filter)}"}) # construct predicate for `before` diff --git a/src/langchain_google_cloud_sql_pg/async_vectorstore.py b/src/langchain_google_cloud_sql_pg/async_vectorstore.py index b8884438..7a0c3217 100644 --- a/src/langchain_google_cloud_sql_pg/async_vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/async_vectorstore.py @@ -15,6 +15,7 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations +import copy import json import uuid from typing import Any, Callable, Iterable, Optional, Sequence @@ -37,6 +38,36 @@ QueryOptions, ) +COMPARISONS_TO_NATIVE = { + "$eq": "=", + "$ne": "!=", + "$lt": "<", + "$lte": "<=", + "$gt": ">", + "$gte": ">=", +} + +SPECIAL_CASED_OPERATORS = { + "$in", + "$nin", + "$between", + "$exists", +} + +TEXT_OPERATORS = { + "$like", + "$ilike", +} + +LOGICAL_OPERATORS = {"$and", "$or", "$not"} + +SUPPORTED_OPERATORS = ( + set(COMPARISONS_TO_NATIVE) + .union(TEXT_OPERATORS) + .union(LOGICAL_OPERATORS) + .union(SPECIAL_CASED_OPERATORS) +) + class AsyncPostgresVectorStore(VectorStore): """Google Cloud SQL for PostgreSQL Vector Store class""" @@ -236,6 +267,9 @@ async def __aadd_embeddings( """ if not ids: ids = [str(uuid.uuid4()) for _ in texts] + else: + # This is done to fill in any missing ids + ids = [id if id is not None else str(uuid.uuid4()) for id in ids] if not metadatas: metadatas = [{} for _ in texts] # Insert embeddings @@ -254,7 +288,7 @@ async def __aadd_embeddings( values_stmt = "VALUES (:id, :content, :embedding" # Add metadata - extra = metadata + extra = copy.deepcopy(metadata) for metadata_column in self.metadata_columns: if metadata_column in metadata: values_stmt += f", :{metadata_column}" @@ -275,13 +309,66 @@ async def __aadd_embeddings( else: values_stmt += ")" - query = insert_stmt + values_stmt + upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"' + + if self.metadata_json_column: + upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"' + + for 
column in self.metadata_columns: + upsert_stmt += f', "{column}" = EXCLUDED."{column}"' + + upsert_stmt += ";" + + query = insert_stmt + values_stmt + upsert_stmt async with self.pool.connect() as conn: await conn.execute(text(query), values) await conn.commit() return ids + async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]: + """Get documents by ids.""" + + quoted_ids = [f"'{id_val}'" for id_val in ids] + id_list_str = ", ".join(quoted_ids) + + columns = self.metadata_columns + [ + self.id_column, + self.content_column, + ] + if self.metadata_json_column: + columns.append(self.metadata_json_column) + + column_names = ", ".join(f'"{col}"' for col in columns) + + query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({id_list_str});' + + async with self.pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + + documents = [] + for row in results: + metadata = ( + row[self.metadata_json_column] + if self.metadata_json_column and row[self.metadata_json_column] + else {} + ) + for col in self.metadata_columns: + metadata[col] = row[col] + documents.append( + ( + Document( + page_content=row[self.content_column], + metadata=metadata, + id=str(row[self.id_column]), + ) + ) + ) + + return documents + async def aadd_texts( self, texts: Iterable[str], @@ -313,6 +400,8 @@ async def aadd_documents( """ texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] + if not ids: + ids = [doc.id for doc in documents] ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs) return ids @@ -483,7 +572,7 @@ async def __query_collection( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> Sequence[RowMapping]: """Perform similarity search query on the vector store table.""" @@ -500,6 +589,9 @@ async def __query_collection( columns.append(self.metadata_json_column) column_names = ", ".join(f'"{col}"' for col in columns) + + if filter and isinstance(filter, dict): + filter = self._create_filter_clause(filter) filter = f"WHERE {filter}" if filter else "" embedding_string = f"'{[float(dimension) for dimension in embedding]}'" stmt = f'SELECT {column_names}, {search_function}({self.embedding_column}, {embedding_string}) as distance FROM "{self.schema_name}"."{self.table_name}" {filter} ORDER BY {self.embedding_column} {operator} {embedding_string} LIMIT {k};' @@ -522,7 +614,7 @@ async def asimilarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -547,7 +639,7 @@ async def asimilarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -561,7 +653,7 @@ async def asimilarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -575,7 +667,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: list[float], k: 
Optional[int] = None,
-        filter: Optional[str] = None,
+        filter: Optional[dict] | Optional[str] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Return docs and distance scores selected by vector similarity search."""
@@ -597,6 +689,7 @@ async def asimilarity_search_with_score_by_vector(
                 Document(
                     page_content=row[self.content_column],
                     metadata=metadata,
+                    id=str(row[self.id_column]),
                 ),
                 row["distance"],
             )
@@ -610,7 +703,7 @@ async def amax_marginal_relevance_search(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[str] = None,
+        filter: Optional[dict] | Optional[str] = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected using the maximal marginal relevance."""
@@ -631,7 +724,7 @@ async def amax_marginal_relevance_search_by_vector(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[str] = None,
+        filter: Optional[dict] | Optional[str] = None,
         **kwargs: Any,
     ) -> list[Document]:
         """Return docs selected using the maximal marginal relevance."""
@@ -654,7 +747,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
-        filter: Optional[str] = None,
+        filter: Optional[dict] | Optional[str] = None,
         **kwargs: Any,
     ) -> list[tuple[Document, float]]:
         """Return docs and distance scores selected using the maximal marginal relevance."""
@@ -687,6 +780,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
                 Document(
                     page_content=row[self.content_column],
                     metadata=metadata,
+                    id=str(row[self.id_column]),
                 ),
                 row["distance"],
             )
@@ -759,11 +853,204 @@ async def is_valid_index(
 
         return bool(len(results) == 1)
 
+    def _handle_field_filter(
+        self,
+        field: str,
+        value: Any,
+    ) -> str:
+        """Create a filter for a specific field.
+        Args:
+            field: name of the field
+            value: value to filter on.
+                If provided as-is, this becomes an equality filter.
+                If provided as a dictionary, the key is the operator and
+                the value is the value to filter by.
+        Returns:
+            A SQL WHERE clause fragment as a string.
+        """
+        if not isinstance(field, str):
+            raise ValueError(
+                f"field should be a string but got: {type(field)} with value: {field}"
+            )
+
+        if field.startswith("$"):
+            raise ValueError(
+                f"Invalid filter condition. Expected a field but got an operator: "
+                f"{field}"
+            )
+
+        # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
+        if not field.isidentifier():
+            raise ValueError(
+                f"Invalid field name: {field}. Expected a valid identifier."
+            )
+
+        if isinstance(value, dict):
+            # This is a filter specification
+            if len(value) != 1:
+                raise ValueError(
+                    "Invalid filter condition. Expected a value which "
+                    "is a dictionary with a single key that corresponds to an operator "
+                    f"but got a dictionary with {len(value)} keys. The first few "
+                    f"keys are: {list(value.keys())[:3]}"
+                )
+            operator, filter_value = list(value.items())[0]
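+            # e.g. value={"$lt": 100} unpacks to operator="$lt", filter_value=100,
+            # while a bare scalar value falls through to the "$eq" branch below.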
" + f"Expected one of {SUPPORTED_OPERATORS}" + ) + else: # Then we assume an equality operator + operator = "$eq" + filter_value = value + + if operator in COMPARISONS_TO_NATIVE: + # Then we implement an equality filter + # native is trusted input + if isinstance(filter_value, str): + filter_value = f"'{filter_value}'" + native = COMPARISONS_TO_NATIVE[operator] + return f"({field} {native} {filter_value})" + elif operator == "$between": + # Use AND with two comparisons + low, high = filter_value + + return f"({field} BETWEEN {low} AND {high})" + elif operator in {"$in", "$nin", "$like", "$ilike"}: + # We'll do force coercion to text + if operator in {"$in", "$nin"}: + for val in filter_value: + if not isinstance(val, (str, int, float)): + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + + if isinstance(val, bool): # b/c bool is an instance of int + raise NotImplementedError( + f"Unsupported type: {type(val)} for value: {val}" + ) + + if operator in {"$in"}: + values = str(tuple(val for val in filter_value)) + return f"({field} IN {values})" + elif operator in {"$nin"}: + values = str(tuple(val for val in filter_value)) + return f"({field} NOT IN {values})" + elif operator in {"$like"}: + return f"({field} LIKE '{filter_value}')" + elif operator in {"$ilike"}: + return f"({field} ILIKE '{filter_value}')" + else: + raise NotImplementedError() + elif operator == "$exists": + if not isinstance(filter_value, bool): + raise ValueError( + "Expected a boolean value for $exists " + f"operator, but got: {filter_value}" + ) + else: + if filter_value: + return f"({field} IS NOT NULL)" + else: + return f"({field} IS NULL)" + else: + raise NotImplementedError() + + def _create_filter_clause(self, filters: Any) -> str: + """Create LangChain filter representation to matching SQL where clauses + Args: + filters: Dictionary of filters to apply to the query. + Returns: + String containing the sql where query. + """ + + if not isinstance(filters, dict): + raise ValueError( + f"Invalid type: Expected a dictionary but got type: {type(filters)}" + ) + if len(filters) == 1: + # The only operators allowed at the top level are $AND, $OR, and $NOT + # First check if an operator or a field + key, value = list(filters.items())[0] + if key.startswith("$"): + # Then it's an operator + if key.lower() not in ["$and", "$or", "$not"]: + raise ValueError( + f"Invalid filter condition. Expected $and, $or or $not " + f"but got: {key}" + ) + else: + # Then it's a field + return self._handle_field_filter(key, filters[key]) + + if key.lower() == "$and" or key.lower() == "$or": + if not isinstance(value, list): + raise ValueError( + f"Expected a list, but got {type(value)} for value: {value}" + ) + op = key[1:].upper() # Extract the operator + filter_clause = [self._create_filter_clause(el) for el in value] + if len(filter_clause) > 1: + return f"({f' {op} '.join(filter_clause)})" + elif len(filter_clause) == 1: + return filter_clause[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + elif key.lower() == "$not": + if isinstance(value, list): + not_conditions = [ + self._create_filter_clause(item) for item in value + ] + not_stmts = [f"NOT {condition}" for condition in not_conditions] + return f"({' AND '.join(not_stmts)})" + elif isinstance(value, dict): + not_ = self._create_filter_clause(value) + return f"(NOT {not_})" + else: + raise ValueError( + f"Invalid filter condition. 
Expected a dictionary " + f"or a list but got: {type(value)}" + ) + else: + raise ValueError( + f"Invalid filter condition. Expected $and, $or or $not " + f"but got: {key}" + ) + elif len(filters) > 1: + # Then all keys have to be fields (they cannot be operators) + for key in filters.keys(): + if key.startswith("$"): + raise ValueError( + f"Invalid filter condition. Expected a field but got: {key}" + ) + # These should all be fields and combined using an $and operator + and_ = [self._handle_field_filter(k, v) for k, v in filters.items()] + if len(and_) > 1: + return f"({' AND '.join(and_)})" + elif len(and_) == 1: + return and_[0] + else: + raise ValueError( + "Invalid filter condition. Expected a dictionary " + "but got an empty dictionary" + ) + else: + return "" + + def get_by_ids(self, ids: Sequence[str]) -> list[Document]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + def similarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -845,7 +1132,7 @@ def similarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( @@ -856,7 +1143,7 @@ def similarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -867,7 +1154,7 @@ def similarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( @@ -880,7 +1167,7 @@ def max_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -893,7 +1180,7 @@ def max_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: raise NotImplementedError( @@ -906,7 +1193,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: raise NotImplementedError( diff --git a/src/langchain_google_cloud_sql_pg/vectorstore.py b/src/langchain_google_cloud_sql_pg/vectorstore.py index de7275de..f5333fd6 100644 --- a/src/langchain_google_cloud_sql_pg/vectorstore.py +++ b/src/langchain_google_cloud_sql_pg/vectorstore.py @@ -15,7 +15,7 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Sequence import numpy as np from langchain_core.documents import Document @@ -551,7 +551,7 @@ async def asimilarity_search( self, query: str, k: 
Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -563,7 +563,7 @@ def similarity_search( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by similarity search on query.""" @@ -586,7 +586,7 @@ async def asimilarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -598,7 +598,7 @@ def similarity_search_with_score( self, query: str, k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on query.""" @@ -610,7 +610,7 @@ async def asimilarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -622,7 +622,7 @@ def similarity_search_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected by vector similarity search.""" @@ -634,7 +634,7 @@ async def asimilarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by vector similarity search.""" @@ -648,7 +648,7 @@ def similarity_search_with_score_by_vector( self, embedding: list[float], k: Optional[int] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected by similarity search on vector.""" @@ -664,7 +664,7 @@ async def amax_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -680,7 +680,7 @@ def max_marginal_relevance_search( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -696,7 +696,7 @@ async def amax_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -712,7 +712,7 @@ def max_marginal_relevance_search_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = 
None, **kwargs: Any, ) -> list[Document]: """Return docs selected using the maximal marginal relevance.""" @@ -728,7 +728,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" @@ -744,7 +744,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: Optional[int] = None, fetch_k: Optional[int] = None, lambda_mult: Optional[float] = None, - filter: Optional[str] = None, + filter: Optional[dict] | Optional[str] = None, **kwargs: Any, ) -> list[tuple[Document, float]]: """Return docs and distance scores selected using the maximal marginal relevance.""" @@ -813,3 +813,11 @@ def is_valid_index( ) -> bool: """Check if index exists in the table.""" return self._engine._run_as_sync(self.__vs.is_valid_index(index_name)) + + async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]: + """Get documents by ids.""" + return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids)) + + def get_by_ids(self, ids: Sequence[str]) -> list[Document]: + """Get documents by ids.""" + return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids)) diff --git a/src/langchain_google_cloud_sql_pg/version.py b/src/langchain_google_cloud_sql_pg/version.py index 0af59b38..69ad501b 100644 --- a/src/langchain_google_cloud_sql_pg/version.py +++ b/src/langchain_google_cloud_sql_pg/version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.13.0" +__version__ = "0.14.0" diff --git a/tests/metadata_filtering_data.py b/tests/metadata_filtering_data.py new file mode 100644 index 00000000..d983e331 --- /dev/null +++ b/tests/metadata_filtering_data.py @@ -0,0 +1,263 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
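+
+# These fixtures exercise the dict-based `filter` syntax this release adds to
+# the vector stores, e.g.:
+#
+#     await vs.asimilarity_search(
+#         "query", k=5, filter={"$and": [{"price": {"$lte": 100.0}}, {"is_available": True}]}
+#     )
+#
+# METADATAS holds the product records; FILTERING_TEST_CASES pairs each filter
+# with the product codes it is expected to match.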
+ +METADATAS = [ + { + "name": "Wireless Headphones", + "code": "WH001", + "price": 149.99, + "is_available": True, + "release_date": "2023-10-26", + "tags": ["audio", "wireless", "electronics"], + "dimensions": [18.5, 7.2, 21.0], + "inventory_location": [101, 102], + "available_quantity": 50, + }, + { + "name": "Ergonomic Office Chair", + "code": "EC002", + "price": 299.00, + "is_available": True, + "release_date": "2023-08-15", + "tags": ["furniture", "office", "ergonomic"], + "dimensions": [65.0, 60.0, 110.0], + "inventory_location": [201], + "available_quantity": 10, + }, + { + "name": "Stainless Steel Water Bottle", + "code": "WB003", + "price": 25.50, + "is_available": False, + "release_date": "2024-01-05", + "tags": ["hydration", "eco-friendly", "kitchen"], + "dimensions": [7.5, 7.5, 25.0], + "available_quantity": 0, + }, + { + "name": "Smart Fitness Tracker", + "code": "FT004", + "price": 79.95, + "is_available": True, + "release_date": "2023-11-12", + "tags": ["fitness", "wearable", "technology"], + "dimensions": [2.0, 1.0, 25.0], + "inventory_location": [401], + "available_quantity": 100, + }, +] + +FILTERING_TEST_CASES = [ + # These tests only involve equality checks + ( + {"code": "FT004"}, + ["FT004"], + ), + # String field + ( + # check name + {"name": "Smart Fitness Tracker"}, + ["FT004"], + ), + # Boolean fields + ( + {"is_available": True}, + ["WH001", "FT004", "EC002"], + ), + # And semantics for top level filtering + ( + {"code": "WH001", "is_available": True}, + ["WH001"], + ), + # These involve equality checks and other operators + # like $ne, $gt, $gte, $lt, $lte + ( + {"available_quantity": {"$eq": 10}}, + ["EC002"], + ), + ( + {"available_quantity": {"$ne": 0}}, + ["WH001", "FT004", "EC002"], + ), + ( + {"available_quantity": {"$gt": 60}}, + ["FT004"], + ), + ( + {"available_quantity": {"$gte": 50}}, + ["WH001", "FT004"], + ), + ( + {"available_quantity": {"$lt": 5}}, + ["WB003"], + ), + ( + {"available_quantity": {"$lte": 10}}, + ["WB003", "EC002"], + ), + # Repeat all the same tests with name (string column) + ( + {"code": {"$eq": "WH001"}}, + ["WH001"], + ), + ( + {"code": {"$ne": "WB003"}}, + ["WH001", "FT004", "EC002"], + ), + # And also gt, gte, lt, lte relying on lexicographical ordering + ( + {"name": {"$gt": "Wireless Headphones"}}, + [], + ), + ( + {"name": {"$gte": "Wireless Headphones"}}, + ["WH001"], + ), + ( + {"name": {"$lt": "Smart Fitness Tracker"}}, + ["EC002"], + ), + ( + {"name": {"$lte": "Smart Fitness Tracker"}}, + ["FT004", "EC002"], + ), + ( + {"is_available": {"$eq": True}}, + ["WH001", "FT004", "EC002"], + ), + ( + {"is_available": {"$ne": True}}, + ["WB003"], + ), + # Test float column. 
+    (
+        {"price": {"$gt": 200.0}},
+        ["EC002"],
+    ),
+    (
+        {"price": {"$gte": 149.99}},
+        ["WH001", "EC002"],
+    ),
+    (
+        {"price": {"$lt": 50.0}},
+        ["WB003"],
+    ),
+    (
+        {"price": {"$lte": 79.95}},
+        ["FT004", "WB003"],
+    ),
+    # These involve usage of the $and, $or and $not operators
+    (
+        {"$or": [{"code": "WH001"}, {"code": "EC002"}]},
+        ["WH001", "EC002"],
+    ),
+    (
+        {"$or": [{"code": "WH001"}, {"available_quantity": 10}]},
+        ["WH001", "EC002"],
+    ),
+    (
+        {"$and": [{"code": "WH001"}, {"code": "EC002"}]},
+        [],
+    ),
+    # Test for $not operator
+    (
+        {"$not": {"code": "WB003"}},
+        ["WH001", "FT004", "EC002"],
+    ),
+    (
+        {"$not": [{"code": "WB003"}]},
+        ["WH001", "FT004", "EC002"],
+    ),
+    (
+        {"$not": {"available_quantity": 0}},
+        ["WH001", "FT004", "EC002"],
+    ),
+    (
+        {"$not": [{"available_quantity": 0}]},
+        ["WH001", "FT004", "EC002"],
+    ),
+    (
+        {"$not": {"is_available": True}},
+        ["WB003"],
+    ),
+    (
+        {"$not": [{"is_available": True}]},
+        ["WB003"],
+    ),
+    (
+        {"$not": {"price": {"$gt": 150.0}}},
+        ["WH001", "FT004", "WB003"],
+    ),
+    (
+        {"$not": [{"price": {"$gt": 150.0}}]},
+        ["WH001", "FT004", "WB003"],
+    ),
+    # These involve special operators like $in, $nin, $between
+    # Test between
+    (
+        {"available_quantity": {"$between": (40, 60)}},
+        ["WH001"],
+    ),
+    # Test in
+    (
+        {"name": {"$in": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}},
+        ["FT004", "WB003"],
+    ),
+    # With numeric fields
+    (
+        {"available_quantity": {"$in": [0, 10]}},
+        ["WB003", "EC002"],
+    ),
+    # Test nin
+    (
+        {"name": {"$nin": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}},
+        ["WH001", "EC002"],
+    ),
+    # With numeric fields
+    (
+        {"available_quantity": {"$nin": [50, 0, 10]}},
+        ["FT004"],
+    ),
+    # These involve special operators like $like, $ilike that
+    # may be specific to certain databases.
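+    # (In Postgres, LIKE is case-sensitive; ILIKE is its case-insensitive variant.)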
+    (
+        {"name": {"$like": "Wireless%"}},
+        ["WH001"],
+    ),
+    (
+        {"name": {"$like": "%less%"}},  # matches "Wireless Headphones" and "Stainless Steel Water Bottle"
+        ["WH001", "WB003"],
+    ),
+    # These involve the special operator $exists
+    (
+        {"tags": {"$exists": False}},
+        [],
+    ),
+    (
+        {"inventory_location": {"$exists": False}},
+        ["WB003"],
+    ),
+]
+
+NEGATIVE_TEST_CASES = [
+    {"$nor": [{"code": "WH001"}, {"code": "EC002"}]},
+    {"$and": {"is_available": True}},
+    {"is_available": {"$and": True}},
+    {"is_available": {"name": "{Wireless Headphones", "code": "EC002"}},
+    {"my column": {"$and": True}},
+    {"is_available": {"code": "WH001", "code": "EC002"}},
+    {"$and": {}},
+    {"$and": []},
+    {"$not": True},
+]
diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py
index 1cad46a5..69d36518 100644
--- a/tests/test_async_checkpoint.py
+++ b/tests/test_async_checkpoint.py
@@ -39,7 +39,7 @@
     empty_checkpoint,
 )
 from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
-from langgraph.prebuilt import (
+from langgraph.prebuilt import (  # type: ignore[import-not-found]
     ToolNode,
     ValidationNode,
     create_react_agent,
@@ -135,7 +134,6 @@ async def async_engine():
     await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name}"')
     await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name_writes}"')
     await async_engine.close()
-    await async_engine._connector.close_async()
 
 
 @pytest_asyncio.fixture
@@ -173,7 +172,7 @@ async def test_checkpoint_async(
 
 
 @pytest.fixture
-def test_data():
+def test_data() -> dict[str, Any]:
     """Fixture providing test data for checkpoint tests."""
     config_0: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
     config_1: RunnableConfig = {
@@ -224,13 +223,13 @@ def test_data():
         "source": "input",
         "step": 2,
         "writes": {},
-        "parents": 1,
+        "parents": 1,  # type: ignore[typeddict-item]
     }
     metadata_2: CheckpointMetadata = {
         "source": "loop",
         "step": 1,
         "writes": {"foo": "bar"},
-        "parents": None,
+        "parents": None,  # type: ignore[typeddict-item]
     }
     metadata_3: CheckpointMetadata = {}
 
@@ -375,13 +374,14 @@ def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
 
     saved = await checkpointer.aget_tuple(thread_agent_config)
     assert saved is not None
-    assert saved.checkpoint["channel_values"] == {
-        "messages": [
-            _AnyIdHumanMessage(content="hi?"),
-            AIMessage(content="hi?", id="0"),
-        ],
-        "agent": "agent",
-    }
+    assert (
+        _AnyIdHumanMessage(content="hi?")
+        in saved.checkpoint["channel_values"]["messages"]
+    )
+    assert (
+        AIMessage(content="hi?", id="0")
+        in saved.checkpoint["channel_values"]["messages"]
+    )
     assert saved.metadata == {
         "parents": {},
         "source": "loop",
diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py
index 2d808dad..418dbbad 100644
--- a/tests/test_async_vectorstore_search.py
+++ b/tests/test_async_vectorstore_search.py
@@ -19,6 +19,7 @@
 import pytest_asyncio
 from langchain_core.documents import Document
 from langchain_core.embeddings import DeterministicFakeEmbedding
+from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS
 from sqlalchemy import text
 
 from langchain_google_cloud_sql_pg import Column, PostgresEngine
@@ -27,7 +28,9 @@
 
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
+CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_")
 VECTOR_SIZE = 768
+sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead."
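+# DeterministicFakeEmbedding returns a stable vector for a given input, so the
+# ordering assertions in these tests are reproducible across runs.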
embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -37,7 +40,9 @@ docs = [ Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) ] - +filter_docs = [ + Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) +] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] @@ -86,6 +91,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -98,7 +104,6 @@ async def vs(self, engine): embedding_service=embeddings_service, table_name=DEFAULT_TABLE, ) - ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_documents(docs, ids=ids) yield vs @@ -129,26 +134,62 @@ async def vs_custom(self, engine): await vs_custom.aadd_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_custom_filter(self, engine): + await engine._ainit_vectorstore_table( + CUSTOM_FILTER_TABLE, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", "BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + ) + + vs_custom_filter = await AsyncPostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ) + await vs_custom_filter.aadd_documents(filter_docs, ids=ids) + yield vs_custom_filter + async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 - assert results == [Document(page_content="foo")] + assert results == [Document(page_content="foo", id=ids[0])] results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") - assert results == [Document(page_content="bar")] + assert results == [Document(page_content="bar", id=ids[1])] async def test_asimilarity_search_score(self, vs): results = await vs.asimilarity_search_with_score("foo") assert len(results) == 4 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 async def test_asimilarity_search_by_vector(self, vs): embedding = embeddings_service.embed_query("foo") results = await vs.asimilarity_search_by_vector(embedding) assert len(results) == 4 - assert results[0] == Document(page_content="foo") + assert results[0] == Document(page_content="foo", id=ids[0]) results = await vs.asimilarity_search_with_score_by_vector(embedding) - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs): @@ -171,7 +212,7 @@ async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs "foo", **score_threshold ) assert len(results) == 1 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) score_threshold = 
{"score_threshold": 0.02} vs.distance_strategy = DistanceStrategy.EUCLIDEAN @@ -195,78 +236,108 @@ async def test_similarity_search_with_relevance_scores_threshold_euclidean( "foo", **score_threshold ) assert len(results) == 1 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) async def test_amax_marginal_relevance_search(self, vs): results = await vs.amax_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search( "bar", filter="content = 'boo'" ) - assert results[0] == Document(page_content="boo") + assert results[0] == Document(page_content="boo", id=ids[3]) async def test_amax_marginal_relevance_search_vector(self, vs): embedding = embeddings_service.embed_query("bar") results = await vs.amax_marginal_relevance_search_by_vector(embedding) - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) async def test_amax_marginal_relevance_search_vector_score(self, vs): embedding = embeddings_service.embed_query("bar") results = await vs.amax_marginal_relevance_search_with_score_by_vector( embedding ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search_with_score_by_vector( embedding, lambda_mult=0.75, fetch_k=10 ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) async def test_similarity_search(self, vs_custom): results = await vs_custom.asimilarity_search("foo", k=1) assert len(results) == 1 - assert results == [Document(page_content="foo")] + assert results == [Document(page_content="foo", id=ids[0])] results = await vs_custom.asimilarity_search( "foo", k=1, filter="mycontent = 'bar'" ) - assert results == [Document(page_content="bar")] + assert results == [Document(page_content="bar", id=ids[1])] async def test_similarity_search_score(self, vs_custom): results = await vs_custom.asimilarity_search_with_score("foo") assert len(results) == 4 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 async def test_similarity_search_by_vector(self, vs_custom): embedding = embeddings_service.embed_query("foo") results = await vs_custom.asimilarity_search_by_vector(embedding) assert len(results) == 4 - assert results[0] == Document(page_content="foo") + assert results[0] == Document(page_content="foo", id=ids[0]) results = await vs_custom.asimilarity_search_with_score_by_vector(embedding) - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 async def test_max_marginal_relevance_search(self, vs_custom): results = await vs_custom.amax_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs_custom.amax_marginal_relevance_search( "bar", filter="mycontent = 'boo'" ) - assert results[0] == Document(page_content="boo") + assert results[0] == Document(page_content="boo", id=ids[3]) async def test_max_marginal_relevance_search_vector(self, vs_custom): embedding = embeddings_service.embed_query("bar") results = await 
vs_custom.amax_marginal_relevance_search_by_vector(embedding) - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) async def test_max_marginal_relevance_search_vector_score(self, vs_custom): embedding = embeddings_service.embed_query("bar") results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( embedding ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector( embedding, lambda_mult=0.75, fetch_k=10 ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) + + async def test_aget_by_ids(self, vs): + test_ids = [ids[0]] + results = await vs.aget_by_ids(ids=test_ids) + + assert results[0] == Document(page_content="foo", id=ids[0]) + + async def test_aget_by_ids_custom_vs(self, vs_custom): + test_ids = [ids[0]] + results = await vs_custom.aget_by_ids(ids=test_ids) + + assert results[0] == Document(page_content="foo", id=ids[0]) + + def test_get_by_ids(self, vs): + test_ids = [ids[0]] + with pytest.raises(Exception, match=sync_method_exception_str): + vs.get_by_ids(ids=test_ids) + + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + async def test_vectorstore_with_metadata_filters( + self, + vs_custom_filter, + test_filter, + expected_ids, + ): + """Test end to end construction and search.""" + docs = await vs_custom_filter.asimilarity_search( + "meow", k=5, filter=test_filter + ) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 3a9c5bd3..0cd6f4ac 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -103,7 +103,6 @@ async def engine(): await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name}"') await aexecute(engine, f'DROP TABLE IF EXISTS "{table_name_writes}"') await engine.close() - await engine._connector.close_async() @pytest_asyncio.fixture @@ -119,7 +118,6 @@ async def async_engine(): await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name_async}"') await aexecute(async_engine, f'DROP TABLE IF EXISTS "{table_name_writes_async}"') await async_engine.close() - await async_engine._connector.close_async() @pytest_asyncio.fixture @@ -197,7 +195,7 @@ def test_checkpoint_table(engine: Any) -> None: @pytest.fixture -def test_data(): +def test_data() -> dict[str, Any]: """Fixture providing test data for checkpoint tests.""" config_0: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} config_1: RunnableConfig = { @@ -248,13 +246,13 @@ def test_data(): "source": "input", "step": 2, "writes": {}, - "parents": 1, + "parents": 1, # type: ignore[typeddict-item] } metadata_2: CheckpointMetadata = { "source": "loop", "step": 1, "writes": {"foo": "bar"}, - "parents": None, + "parents": None, # type: ignore[typeddict-item] } metadata_3: CheckpointMetadata = {} diff --git a/tests/test_standard_test_suite.py b/tests/test_standard_test_suite.py new file mode 100644 index 00000000..19c77128 --- /dev/null +++ b/tests/test_standard_test_suite.py @@ -0,0 +1,159 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import pytest +import pytest_asyncio +from langchain_tests.integration_tests import VectorStoreIntegrationTests +from langchain_tests.integration_tests.vectorstores import EMBEDDING_SIZE +from sqlalchemy import text + +from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore + +DEFAULT_TABLE = "test_table_standard_test_suite" + str(uuid.uuid4()) +DEFAULT_TABLE_SYNC = "test_table_sync_standard_test_suite" + str(uuid.uuid4()) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.asyncio +class TestStandardSuiteSync(VectorStoreIntegrationTests): + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(loop_scope="function") + async def sync_engine(self, db_project, db_region, db_instance, db_name): + sync_engine = PostgresEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield sync_engine + await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"') + await sync_engine.close() + + @pytest.fixture(scope="function") + def vectorstore(self, sync_engine): + """Get an empty vectorstore for unit tests.""" + sync_engine.init_vectorstore_table( + DEFAULT_TABLE_SYNC, + EMBEDDING_SIZE, + id_column=Column(name="langchain_id", data_type="VARCHAR", nullable=False), + ) + + vs = PostgresVectorStore.create_sync( + sync_engine, + embedding_service=self.get_embeddings(), + table_name=DEFAULT_TABLE_SYNC, + ) + yield vs + + +@pytest.mark.filterwarnings("ignore") +@pytest.mark.asyncio +class TestStandardSuiteAsync(VectorStoreIntegrationTests): + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", 
"instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for Cloud SQL") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for Cloud SQL") + + @pytest_asyncio.fixture(loop_scope="function") + async def async_engine(self, db_project, db_region, db_instance, db_name): + async_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield async_engine + await aexecute(async_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') + await async_engine.close() + + @pytest_asyncio.fixture(loop_scope="function") + async def vectorstore(self, async_engine): + """Get an empty vectorstore for unit tests.""" + await async_engine.ainit_vectorstore_table( + DEFAULT_TABLE, + EMBEDDING_SIZE, + id_column=Column(name="langchain_id", data_type="VARCHAR", nullable=False), + ) + + vs = await PostgresVectorStore.create( + async_engine, + embedding_service=self.get_embeddings(), + table_name=DEFAULT_TABLE, + ) + + yield vs diff --git a/tests/test_vectorstore_search.py b/tests/test_vectorstore_search.py index 3a0ca81a..ae1341ed 100644 --- a/tests/test_vectorstore_search.py +++ b/tests/test_vectorstore_search.py @@ -19,6 +19,7 @@ import pytest_asyncio from langchain_core.documents import Document from langchain_core.embeddings import DeterministicFakeEmbedding +from metadata_filtering_data import FILTERING_TEST_CASES, METADATAS, NEGATIVE_TEST_CASES from sqlalchemy import text from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore @@ -27,6 +28,10 @@ DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") CUSTOM_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE = "test_table_custom_filter" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_FILTER_TABLE_SYNC = "test_table_custom_filter_sync" + str(uuid.uuid4()).replace( + "-", "_" +) VECTOR_SIZE = 768 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -37,7 +42,9 @@ docs = [ Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) ] - +filter_docs = [ + Document(page_content=texts[i], metadata=METADATAS[i]) for i in range(len(texts)) +] embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] @@ -88,6 +95,7 @@ async def engine(self, db_project, db_region, db_instance, db_name): ) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE}") await engine.close() @pytest_asyncio.fixture(scope="class") @@ -100,7 +108,6 @@ async def vs(self, engine): embedding_service=embeddings_service, table_name=DEFAULT_TABLE, ) - ids = [str(uuid.uuid4()) for i in range(len(texts))] await vs.aadd_documents(docs, ids=ids) yield vs @@ -143,26 +150,63 @@ async def vs_custom(self, engine_sync): vs_custom.add_documents(docs, ids=ids) yield vs_custom + @pytest_asyncio.fixture(scope="class") + async def vs_custom_filter(self, engine): + await engine.ainit_vectorstore_table( + CUSTOM_FILTER_TABLE, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", 
"BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + overwrite_existing=True, + ) + + vs_custom_filter = await PostgresVectorStore.create( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ) + await vs_custom_filter.aadd_documents(filter_docs, ids=ids) + yield vs_custom_filter + async def test_asimilarity_search(self, vs): results = await vs.asimilarity_search("foo", k=1) assert len(results) == 1 - assert results == [Document(page_content="foo")] + assert results == [Document(page_content="foo", id=ids[0])] results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'") - assert results == [Document(page_content="bar")] + assert results == [Document(page_content="bar", id=ids[1])] async def test_asimilarity_search_score(self, vs): results = await vs.asimilarity_search_with_score("foo") assert len(results) == 4 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 async def test_asimilarity_search_by_vector(self, vs): embedding = embeddings_service.embed_query("foo") results = await vs.asimilarity_search_by_vector(embedding) assert len(results) == 4 - assert results[0] == Document(page_content="foo") + assert results[0] == Document(page_content="foo", id=ids[0]) results = await vs.asimilarity_search_with_score_by_vector(embedding) - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs): @@ -185,7 +229,7 @@ async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs "foo", **score_threshold ) assert len(results) == 1 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) async def test_similarity_search_with_relevance_scores_threshold_euclidean( self, engine @@ -202,32 +246,57 @@ async def test_similarity_search_with_relevance_scores_threshold_euclidean( "foo", **score_threshold ) assert len(results) == 1 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) async def test_amax_marginal_relevance_search(self, vs): results = await vs.amax_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search( "bar", filter="content = 'boo'" ) - assert results[0] == Document(page_content="boo") + assert results[0] == Document(page_content="boo", id=ids[3]) async def test_amax_marginal_relevance_search_vector(self, vs): embedding = embeddings_service.embed_query("bar") results = await vs.amax_marginal_relevance_search_by_vector(embedding) - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) async def test_amax_marginal_relevance_search_vector_score(self, vs): embedding = embeddings_service.embed_query("bar") results = await vs.amax_marginal_relevance_search_with_score_by_vector( embedding ) - assert results[0][0] == Document(page_content="bar") + 
assert results[0][0] == Document(page_content="bar", id=ids[1]) results = await vs.amax_marginal_relevance_search_with_score_by_vector( embedding, lambda_mult=0.75, fetch_k=10 ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) + + async def test_aget_by_ids(self, vs): + test_ids = [ids[0]] + results = await vs.aget_by_ids(ids=test_ids) + + assert results[0] == Document(page_content="foo", id=ids[0]) + + async def test_aget_by_ids_custom_vs(self, vs_custom): + test_ids = [ids[0]] + results = await vs_custom.aget_by_ids(ids=test_ids) + + assert results[0] == Document(page_content="foo", id=ids[0]) + + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + async def test_vectorstore_with_metadata_filters( + self, + vs_custom_filter, + test_filter, + expected_ids, + ): + """Test end to end construction and search.""" + docs = await vs_custom_filter.asimilarity_search( + "meow", k=5, filter=test_filter + ) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter class TestVectorStoreSearchSync: @@ -257,6 +326,7 @@ async def engine_sync(self, db_project, db_region, db_instance, db_name): ) yield engine await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_SYNC}") + await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_FILTER_TABLE_SYNC}") await engine.close() @pytest.fixture(scope="class") @@ -286,49 +356,112 @@ def vs_custom(self, engine_sync): vs_custom.add_documents(docs, ids=ids) yield vs_custom + @pytest.fixture(scope="class") + def vs_custom_filter_sync(self, engine_sync): + engine_sync.init_vectorstore_table( + CUSTOM_FILTER_TABLE_SYNC, + VECTOR_SIZE, + metadata_columns=[ + Column("name", "TEXT"), + Column("code", "TEXT"), + Column("price", "FLOAT"), + Column("is_available", "BOOLEAN"), + Column("tags", "TEXT[]"), + Column("inventory_location", "INTEGER[]"), + Column("available_quantity", "INTEGER", nullable=True), + ], + id_column="langchain_id", + store_metadata=False, + overwrite_existing=True, + ) + + vs_custom_filter_sync = PostgresVectorStore.create_sync( + engine_sync, + embedding_service=embeddings_service, + table_name=CUSTOM_FILTER_TABLE_SYNC, + metadata_columns=[ + "name", + "code", + "price", + "is_available", + "tags", + "inventory_location", + "available_quantity", + ], + id_column="langchain_id", + ) + + vs_custom_filter_sync.add_documents(filter_docs, ids=ids) + yield vs_custom_filter_sync + def test_similarity_search(self, vs_custom): results = vs_custom.similarity_search("foo", k=1) assert len(results) == 1 - assert results == [Document(page_content="foo")] + assert results == [Document(page_content="foo", id=ids[0])] results = vs_custom.similarity_search("foo", k=1, filter="mycontent = 'bar'") - assert results == [Document(page_content="bar")] + assert results == [Document(page_content="bar", id=ids[1])] def test_similarity_search_score(self, vs_custom): results = vs_custom.similarity_search_with_score("foo") assert len(results) == 4 - assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 def test_similarity_search_by_vector(self, vs_custom): embedding = embeddings_service.embed_query("foo") results = vs_custom.similarity_search_by_vector(embedding) assert len(results) == 4 - assert results[0] == Document(page_content="foo") + assert results[0] == Document(page_content="foo", id=ids[0]) results = vs_custom.similarity_search_with_score_by_vector(embedding) - 
assert results[0][0] == Document(page_content="foo") + assert results[0][0] == Document(page_content="foo", id=ids[0]) assert results[0][1] == 0 def test_max_marginal_relevance_search(self, vs_custom): results = vs_custom.max_marginal_relevance_search("bar") - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) results = vs_custom.max_marginal_relevance_search( "bar", filter="mycontent = 'boo'" ) - assert results[0] == Document(page_content="boo") + assert results[0] == Document(page_content="boo", id=ids[3]) def test_max_marginal_relevance_search_vector(self, vs_custom): embedding = embeddings_service.embed_query("bar") results = vs_custom.max_marginal_relevance_search_by_vector(embedding) - assert results[0] == Document(page_content="bar") + assert results[0] == Document(page_content="bar", id=ids[1]) def test_max_marginal_relevance_search_vector_score(self, vs_custom): embedding = embeddings_service.embed_query("bar") results = vs_custom.max_marginal_relevance_search_with_score_by_vector( embedding ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) results = vs_custom.max_marginal_relevance_search_with_score_by_vector( embedding, lambda_mult=0.75, fetch_k=10 ) - assert results[0][0] == Document(page_content="bar") + assert results[0][0] == Document(page_content="bar", id=ids[1]) + + def test_get_by_ids_custom_vs(self, vs_custom): + test_ids = [ids[0]] + results = vs_custom.get_by_ids(ids=test_ids) + + assert results[0] == Document(page_content="foo", id=ids[0]) + + @pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES) + def test_sync_vectorstore_with_metadata_filters( + self, + vs_custom_filter_sync, + test_filter, + expected_ids, + ): + """Test end to end construction and search.""" + + docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter) + assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter + + @pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES) + def test_metadata_filter_negative_tests(self, vs_custom_filter_sync, test_filter): + with pytest.raises((ValueError, NotImplementedError)): + docs = vs_custom_filter_sync.similarity_search( + "meow", k=5, filter=test_filter + )