diff --git a/src/llama_index_cloud_sql_pg/__init__.py b/src/llama_index_cloud_sql_pg/__init__.py index e827164..2916607 100644 --- a/src/llama_index_cloud_sql_pg/__init__.py +++ b/src/llama_index_cloud_sql_pg/__init__.py @@ -15,6 +15,7 @@ from .document_store import PostgresDocumentStore from .engine import Column, PostgresEngine from .index_store import PostgresIndexStore +from .vector_store import PostgresVectorStore from .version import __version__ _all = [ @@ -22,5 +23,6 @@ "PostgresEngine", "PostgresDocumentStore", "PostgresIndexStore", + "PostgresVectorStore", "__version__", ] diff --git a/src/llama_index_cloud_sql_pg/async_vector_store.py b/src/llama_index_cloud_sql_pg/async_vector_store.py new file mode 100644 index 0000000..82b9857 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/async_vector_store.py @@ -0,0 +1,675 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Remove below import when minimum supported Python version is 3.10 +from __future__ import annotations + +import base64 +import json +import re +import uuid +import warnings +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type + +import numpy as np +from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + FilterCondition, + FilterOperator, + MetadataFilter, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryMode, + VectorStoreQueryResult, +) +from llama_index.core.vector_stores.utils import ( + metadata_dict_to_node, + node_to_metadata_dict, +) +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncEngine + +from .engine import PostgresEngine +from .indexes import ( + DEFAULT_DISTANCE_STRATEGY, + DEFAULT_INDEX_NAME_SUFFIX, + BaseIndex, + DistanceStrategy, + ExactNearestNeighbor, + QueryOptions, +) + + +class AsyncPostgresVectorStore(BasePydanticVectorStore): + """Google Cloud SQL Vector Store class""" + + stores_text: bool = True + is_embedding_query: bool = True + + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + id_column: str = "node_id", + text_column: str = "text", + embedding_column: str = "embedding", + metadata_json_column: str = "li_metadata", + metadata_columns: List[str] = [], + ref_doc_id_column: str = "ref_doc_id", + node_column: str = "node_data", + stores_text: bool = True, + is_embedding_query: bool = True, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + index_query_options: Optional[QueryOptions] = None, + ): + """AsyncPostgresVectorStore constructor. + Args: + key (object): Prevent direct constructor usage. + engine (AsyncEngine): Connection pool engine for managing connections to Cloud SQL database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str): Name of the database schema. Defaults to "public". + id_column (str): Column that represents if of a Node. Defaults to "node_id". + text_column (str): Column that represent text content of a Node. Defaults to "text". + embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". + metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". + node_column (str): Column that represents the whole JSON node. Defaults to "node_data". + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + index_query_options (QueryOptions): Index query option. + + + Raises: + Exception: If called directly by user. + """ + if key != AsyncPostgresVectorStore.__create_key: + raise Exception("Only create class through 'create' method!") + + # Delegate to Pydantic's __init__ + super().__init__(stores_text=stores_text, is_embedding_query=is_embedding_query) + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + self._id_column = id_column + self._text_column = text_column + self._embedding_column = embedding_column + self._metadata_json_column = metadata_json_column + self._metadata_columns = metadata_columns + self._ref_doc_id_column = ref_doc_id_column + self._node_column = node_column + self._distance_strategy = distance_strategy + self._index_query_options = index_query_options + + @classmethod + async def create( + cls: Type[AsyncPostgresVectorStore], + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + id_column: str = "node_id", + text_column: str = "text", + embedding_column: str = "embedding", + metadata_json_column: str = "li_metadata", + metadata_columns: List[str] = [], + ref_doc_id_column: str = "ref_doc_id", + node_column: str = "node_data", + stores_text: bool = True, + is_embedding_query: bool = True, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + index_query_options: Optional[QueryOptions] = None, + ) -> AsyncPostgresVectorStore: + """Create an AsyncPostgresVectorStore instance and validates the table schema. + + Args: + engine (PostgresEngine): PostgresEngine Engine for managing connections to Cloud SQL PG database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str): Name of the database schema. Defaults to "public". + id_column (str): Column that represents if of a Node. Defaults to "node_id". + text_column (str): Column that represent text content of a Node. Defaults to "text". + embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". + metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". + node_column (str): Column that represents the whole JSON node. Defaults to "node_data". + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + index_query_options (QueryOptions): Index query option. + + Raises: + Exception: If table does not exist or follow the provided structure. + + Returns: + AsyncPostgresVectorStore + """ + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" + async with engine._pool.connect() as conn: + result = await conn.execute(text(stmt)) + result_map = result.mappings() + results = result_map.fetchall() + columns = {} + for field in results: + columns[field["column_name"]] = field["data_type"] + + # Check columns + if id_column not in columns: + raise ValueError(f"Id column, {id_column}, does not exist.") + if text_column not in columns: + raise ValueError(f"Text column, {text_column}, does not exist.") + text_type = columns[text_column] + if text_type != "text" and "char" not in text_type: + raise ValueError( + f"Text column, {text_column}, is type, {text_type}. It must be a type of character string." + ) + if embedding_column not in columns: + raise ValueError(f"Embedding column, {embedding_column}, does not exist.") + if columns[embedding_column] != "USER-DEFINED": + raise ValueError( + f"Embedding column, {embedding_column}, is not type Vector." + ) + if node_column not in columns: + raise ValueError(f"Node column, {node_column}, does not exist.") + if columns[node_column] != "json": + raise ValueError(f"Node column, {node_column}, is not type JSON.") + if ref_doc_id_column not in columns: + raise ValueError( + f"Reference Document Id column, {ref_doc_id_column}, does not exist." + ) + if metadata_json_column not in columns: + raise ValueError( + f"Metadata column, {metadata_json_column}, does not exist." + ) + if columns[metadata_json_column] != "jsonb": + raise ValueError( + f"Metadata column, {metadata_json_column}, is not type JSONB." + ) + # If using metadata_columns check to make sure column exists + for column in metadata_columns: + if column not in columns: + raise ValueError(f"Metadata column, {column}, does not exist.") + + return cls( + cls.__create_key, + engine._pool, + table_name, + schema_name=schema_name, + id_column=id_column, + text_column=text_column, + embedding_column=embedding_column, + metadata_json_column=metadata_json_column, + metadata_columns=metadata_columns, + ref_doc_id_column=ref_doc_id_column, + node_column=node_column, + stores_text=stores_text, + is_embedding_query=is_embedding_query, + distance_strategy=distance_strategy, + index_query_options=index_query_options, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncPostgresVectorStore" + + @property + def client(self) -> Any: + """Get client.""" + return self._engine + + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + """Asynchronously add nodes to the table.""" + ids = [] + metadata_col_names = ( + ", " + ", ".join(self._metadata_columns) + if len(self._metadata_columns) > 0 + else "" + ) + metadata_col_values = ( + ", :" + ", :".join(self._metadata_columns) + if len(self._metadata_columns) > 0 + else "" + ) + insert_stmt = f"""INSERT INTO "{self._schema_name}"."{self._table_name}"( + {self._id_column}, + {self._text_column}, + {self._embedding_column}, + {self._metadata_json_column}, + {self._ref_doc_id_column}, + {self._node_column} + {metadata_col_names} + ) VALUES (:node_id, :text, :embedding, :li_metadata, :ref_doc_id, :node_data {metadata_col_values}) + """ + node_values_list = [] + for node in nodes: + metadata = json.dumps( + node_to_metadata_dict( + node, remove_text=self.stores_text, flat_metadata=False + ) + ) + node_values = { + "node_id": node.node_id, + "text": node.get_content(metadata_mode=MetadataMode.NONE), + "embedding": str(node.get_embedding()), + "li_metadata": metadata, + "ref_doc_id": node.ref_doc_id, + "node_data": node.to_json(), + } + for metadata_column in self._metadata_columns: + if metadata_column in node.metadata: + node_values[metadata_column] = node.metadata.get(metadata_column) + else: + node_values[metadata_column] = None + node_values_list.append(node_values) + ids.append(node.node_id) + async with self._engine.connect() as conn: + await conn.execute(text(insert_stmt), node_values_list) + await conn.commit() + return ids + + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Asynchronously delete nodes belonging to provided parent document from the table.""" + query = f"""DELETE FROM "{self._schema_name}"."{self._table_name}" WHERE {self._ref_doc_id_column} = '{ref_doc_id}'""" + async with self._engine.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def adelete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" + if not node_ids and not filters: + return + all_filters: List[MetadataFilter | MetadataFilters] = [] + if node_ids: + all_filters.append( + MetadataFilter( + key=self._id_column, value=node_ids, operator=FilterOperator.IN + ) + ) + if filters: + all_filters.append(filters) + filters_stmt = "" + if all_filters: + all_metadata_filters = MetadataFilters( + filters=all_filters, condition=FilterCondition.AND + ) + filters_stmt = self.__parse_metadata_filters_recursively( + all_metadata_filters + ) + filters_stmt = f"WHERE {filters_stmt}" if filters_stmt else "" + query = f'DELETE FROM "{self._schema_name}"."{self._table_name}" {filters_stmt}' + async with self._engine.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def aclear(self) -> None: + """Asynchronously delete all nodes from the table.""" + query = f'TRUNCATE TABLE "{self._schema_name}"."{self._table_name}"' + async with self._engine.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def aget_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[BaseNode]: + """Asynchronously get nodes from the table matching the provided nodes and filters.""" + query = VectorStoreQuery( + node_ids=node_ids, filters=filters, similarity_top_k=-1 + ) + result = await self.aquery(query) + return list(result.nodes) if result.nodes else [] + + async def aquery( + self, query: VectorStoreQuery, **kwargs: Any + ) -> VectorStoreQueryResult: + """Asynchronously query vector store.""" + results = await self.__query_columns(query) + nodes = [] + ids = [] + similarities = [] + + for row in results: + node = metadata_dict_to_node( + row[self._metadata_json_column], row[self._text_column] + ) + if row[self._ref_doc_id_column]: + node_source = TextNode(id_=row[self._ref_doc_id_column]) + node.relationships[NodeRelationship.SOURCE] = ( + node_source.as_related_node_info() + ) + nodes.append(node) + ids.append(row[self._id_column]) + if "distance" in row: + similarities.append(row["distance"]) + return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) + + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def delete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def clear(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def get_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[BaseNode]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + raise NotImplementedError( + "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + ) + + async def set_maintenance_work_mem(self, num_leaves: int, vector_size: int) -> None: + """Set database maintenance work memory (for index creation).""" + # Required index memory in MB + buffer = 1 + index_memory_required = ( + round(50 * num_leaves * vector_size * 4 / 1024 / 1024) + buffer + ) # Convert bytes to MB + query = f"SET maintenance_work_mem TO '{index_memory_required} MB';" + async with self._engine.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def aapply_vector_index( + self, + index: BaseIndex, + name: Optional[str] = None, + concurrently: bool = False, + ) -> None: + """Create index in the vector store table.""" + if isinstance(index, ExactNearestNeighbor): + await self.adrop_vector_index() + return + + function = index.distance_strategy.index_function + filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else "" + params = "WITH " + index.index_options() + if name is None: + if index.name == None: + index.name = self._table_name + DEFAULT_INDEX_NAME_SUFFIX + name = index.name + stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON \"{self._schema_name}\".\"{self._table_name}\" USING {index.index_type} ({self._embedding_column} {function}) {params} {filter};" + if concurrently: + async with self._engine.connect() as conn: + await conn.execute(text("COMMIT")) + await conn.execute(text(stmt)) + else: + async with self._engine.connect() as conn: + await conn.execute(text(stmt)) + await conn.commit() + + async def areindex(self, index_name: Optional[str] = None) -> None: + """Re-index the vector store table.""" + index_name = index_name or self._table_name + DEFAULT_INDEX_NAME_SUFFIX + query = f"REINDEX INDEX {index_name};" + async with self._engine.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def adrop_vector_index( + self, + index_name: Optional[str] = None, + ) -> None: + """Drop the vector index.""" + index_name = index_name or self._table_name + DEFAULT_INDEX_NAME_SUFFIX + query = f"DROP INDEX IF EXISTS {index_name};" + async with self._engine.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + async def is_valid_index( + self, + index_name: Optional[str] = None, + ) -> bool: + """Check if index exists in the table.""" + index_name = index_name or self._table_name + DEFAULT_INDEX_NAME_SUFFIX + query = f""" + SELECT tablename, indexname + FROM pg_indexes + WHERE tablename = '{self._table_name}' AND schemaname = '{self._schema_name}' AND indexname = '{index_name}'; + """ + async with self._engine.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + results = result_map.fetchall() + return bool(len(results) == 1) + + async def __query_columns( + self, + query: VectorStoreQuery, + **kwargs: Any, + ) -> Sequence[RowMapping]: + """Perform search query on database.""" + filters: List[MetadataFilter | MetadataFilters] = [] + if query.doc_ids: + filters.append( + MetadataFilter( + key=self._ref_doc_id_column, + value=query.doc_ids, + operator=FilterOperator.IN, + ) + ) + if query.node_ids: + filters.append( + MetadataFilter( + key=self._id_column, + value=query.node_ids, + operator=FilterOperator.IN, + ) + ) + if query.filters: + filters.append(query.filters) + + # Note: + # Hybrid search is not yet supported, so following fields in `query` are ignored: + # query_str, mode, alpha, mmr_threshold, sparse_top_k, hybrid_top_k + # Vectors are already stored `self._embedding_column` so a custom embedding_field is ignored. + query_filters = MetadataFilters(filters=filters, condition=FilterCondition.AND) + + filters_stmt = self.__parse_metadata_filters_recursively(query_filters) + filters_stmt = f"WHERE {filters_stmt}" if filters_stmt else "" + operator = self._distance_strategy.operator + search_function = self._distance_strategy.search_function + + # query_embedding is used for scoring + scoring_stmt = ( + f", {search_function}({self._embedding_column}, '{query.query_embedding}') as distance" + if query.query_embedding + else "" + ) + + # results are sorted on ORDER BY query_embedding + order_stmt = ( + f" ORDER BY {self._embedding_column} {operator} '{query.query_embedding}' " + if query.query_embedding + else "" + ) + + # similarity_top_k is used for limiting number of retrieved nodes + limit_stmt = ( + f" LIMIT {query.similarity_top_k} " if query.similarity_top_k >= 1 else "" + ) + + query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}' + async with self._engine.connect() as conn: + if self._index_query_options: + query_options_stmt = ( + f"SET LOCAL {self._index_query_options.to_string()};" + ) + await conn.execute(text(query_options_stmt)) + result = await conn.execute(text(query_stmt)) + result_map = result.mappings() + results = result_map.fetchall() + return results + + def __parse_metadata_filters_recursively( + self, metadata_filters: MetadataFilters + ) -> str: + """ + Parses a MetadataFilters object into a SQL WHERE clause. + Supports a mixed list of MetadataFilter and nested MetadataFilters. + """ + if not metadata_filters.filters: + return "" + + where_clauses = [] + for filter_item in metadata_filters.filters: + if isinstance(filter_item, MetadataFilter): + clause = self.__parse_metadata_filter(filter_item) + if clause: + where_clauses.append(clause) + elif isinstance(filter_item, MetadataFilters): + # Handle nested filters recursively + nested_clause = self.__parse_metadata_filters_recursively(filter_item) + if nested_clause: + where_clauses.append(f"({nested_clause})") + + # Combine clauses with the specified condition + condition_value = ( + metadata_filters.condition.value + if metadata_filters.condition + else FilterCondition.AND.value + ) + return f" {condition_value} ".join(where_clauses) if where_clauses else "" + + def __parse_metadata_filter(self, filter: MetadataFilter) -> str: + key = self.__to_postgres_key(filter.key) + op = self.__to_postgres_operator(filter.operator) + if filter.operator == FilterOperator.IS_EMPTY: + # checks for emptiness of a field, so value is ignored + # cast to jsonb to check array length + return f"((({key})::jsonb IS NULL) OR (jsonb_array_length(({key})::jsonb) = 0))" + if filter.operator == FilterOperator.CONTAINS: + # Expects a list stored in the metadata, and a single value to compare + if isinstance(filter.value, list): + # skip improperly provided filter and raise a warning + warnings.warn( + f"""Expecting a scalar in the filter value, but got {type(filter.value)}. + Ignoring this filter: + Key -> '{filter.key}' + Operator -> '{filter.operator}' + Value -> '{filter.value}'""" + ) + return "" + return f"({key})::jsonb {op} '[\"{filter.value}\"]' " + if filter.operator == FilterOperator.TEXT_MATCH: + return f"{key} {op} '%{filter.value}%' " + if filter.operator in [ + FilterOperator.ANY, + FilterOperator.ALL, + FilterOperator.IN, + FilterOperator.NIN, + ]: + # Expect a single value in metadata and a list to compare + if not isinstance(filter.value, list): + # skip improperly provided filter and raise a warning + warnings.warn( + f"""Expecting List in the filter value, but got {type(filter.value)}. + Ignoring this filter: + Key -> '{filter.key}' + Operator -> '{filter.operator}' + Value -> '{filter.value}'""" + ) + return "" + filter_value = ", ".join(f"'{e}'" for e in filter.value) + if filter.operator in [FilterOperator.ANY, FilterOperator.ALL]: + return f"({key})::jsonb {op} (ARRAY[{filter_value}])" + else: + return f"{key} {op} ({filter_value})" + + # Check if value is a number. If so, cast the metadata value to a float + # This is necessary because the metadata is stored as a string. + if isinstance(filter.value, (int, float, str)): + try: + return f"{key}::float {op} {float(filter.value)}" + except ValueError: + # If not a number, then treat it as a string + pass + return f"{key} {op} '{filter.value}'" + + def __to_postgres_operator(self, operator: FilterOperator) -> str: + if operator == FilterOperator.EQ: + return "=" + elif operator == FilterOperator.GT: + return ">" + elif operator == FilterOperator.LT: + return "<" + elif operator == FilterOperator.NE: + return "!=" + elif operator == FilterOperator.GTE: + return ">=" + elif operator == FilterOperator.LTE: + return "<=" + elif operator == FilterOperator.IN: + return "IN" + elif operator == FilterOperator.NIN: + return "NOT IN" + elif operator == FilterOperator.ANY: + return "?|" + elif operator == FilterOperator.ALL: + return "?&" + elif operator == FilterOperator.CONTAINS: + return "@>" + elif operator == FilterOperator.TEXT_MATCH: + return "LIKE" + elif operator == FilterOperator.IS_EMPTY: + return "IS_EMPTY" + else: + warnings.warn(f"Unknown operator: {operator}, fallback to '='") + return "=" + + def __to_postgres_key(self, key: str) -> str: + if key in [ + *self._metadata_columns, + self._id_column, + self._ref_doc_id_column, + self._text_column, + ]: + return key + return f"{self._metadata_json_column}->>'{key}'" diff --git a/src/llama_index_cloud_sql_pg/indexes.py b/src/llama_index_cloud_sql_pg/indexes.py new file mode 100644 index 0000000..b4a121e --- /dev/null +++ b/src/llama_index_cloud_sql_pg/indexes.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class StrategyMixin: + operator: str + search_function: str + index_function: str + + +class DistanceStrategy(StrategyMixin, enum.Enum): + """Enumerator of the Distance strategies.""" + + EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops" + COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops" + INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops" + + +DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE +DEFAULT_INDEX_NAME_SUFFIX: str = "li_vectorindex" + + +@dataclass +class BaseIndex(ABC): + name: Optional[str] = None + index_type: str = "base" + distance_strategy: DistanceStrategy = field( + default_factory=lambda: DistanceStrategy.COSINE_DISTANCE + ) + partial_indexes: Optional[List[str]] = None + + @abstractmethod + def index_options(self) -> str: + """Set index query options for vector store initialization.""" + raise NotImplementedError( + "index_options method must be implemented by subclass" + ) + + +@dataclass +class ExactNearestNeighbor(BaseIndex): + index_type: str = "exactnearestneighbor" + + +@dataclass +class QueryOptions(ABC): + @abstractmethod + def to_string(self) -> str: + """Convert index attributes to string.""" + raise NotImplementedError("to_string method must be implemented by subclass") + + +@dataclass +class HNSWIndex(BaseIndex): + index_type: str = "hnsw" + m: int = 16 + ef_construction: int = 64 + + def index_options(self) -> str: + """Set index query options for vector store initialization.""" + return f"(m = {self.m}, ef_construction = {self.ef_construction})" + + +@dataclass +class HNSWQueryOptions(QueryOptions): + ef_search: int = 40 + + def to_string(self) -> str: + """Convert index attributes to string.""" + return f"hnsw.ef_search = {self.ef_search}" + + +@dataclass +class IVFFlatIndex(BaseIndex): + index_type: str = "ivfflat" + lists: int = 100 + + def index_options(self) -> str: + """Set index query options for vector store initialization.""" + return f"(lists = {self.lists})" + + +@dataclass +class IVFFlatQueryOptions(QueryOptions): + probes: int = 1 + + def to_string(self) -> str: + """Convert index attributes to string.""" + return f"ivfflat.probes = {self.probes}" + + +@dataclass +class IVFIndex(BaseIndex): + index_type: str = "ivf" + lists: int = 100 + quantizer: str = field( + default="sq8", init=False + ) # Disable `quantizer` initialization currently only supports the value "sq8" + + def index_options(self) -> str: + """Set index query options for vector store initialization.""" + return f"(lists = {self.lists}, quantizer = {self.quantizer})" + + +@dataclass +class IVFQueryOptions(QueryOptions): + probes: int = 1 + + def to_string(self) -> str: + """Convert index attributes to string.""" + return f"ivf.probes = {self.probes}" diff --git a/src/llama_index_cloud_sql_pg/vector_store.py b/src/llama_index_cloud_sql_pg/vector_store.py new file mode 100644 index 0000000..3561760 --- /dev/null +++ b/src/llama_index_cloud_sql_pg/vector_store.py @@ -0,0 +1,359 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, List, Optional, Sequence, Type + +from llama_index.core.schema import BaseNode +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryResult, +) + +from .async_vector_store import AsyncPostgresVectorStore +from .engine import PostgresEngine +from .indexes import ( + DEFAULT_DISTANCE_STRATEGY, + BaseIndex, + DistanceStrategy, + QueryOptions, +) + + +class PostgresVectorStore(BasePydanticVectorStore): + """Google Cloud SQL Vector Store class""" + + __create_key = object() + + def __init__( + self, + key: object, + engine: PostgresEngine, + vs: AsyncPostgresVectorStore, + stores_text: bool = True, + is_embedding_query: bool = True, + ): + """PostgresVectorStore constructor. + Args: + key (object): Prevent direct constructor usage. + engine (PostgresEngine): Connection pool engine for managing connections to Postgres database. + vs (AsyncPostgresVectorStore): The async only Vector Store implementation + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". + + Raises: + Exception: If called directly by user. + """ + if key != PostgresVectorStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + + # Delegate to Pydantic's __init__ + super().__init__(stores_text=stores_text, is_embedding_query=is_embedding_query) + + self._engine = engine + self.__vs = vs + + @classmethod + async def create( + cls: Type[PostgresVectorStore], + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + id_column: str = "node_id", + text_column: str = "text", + embedding_column: str = "embedding", + metadata_json_column: str = "li_metadata", + metadata_columns: List[str] = [], + ref_doc_id_column: str = "ref_doc_id", + node_column: str = "node_data", + stores_text: bool = True, + is_embedding_query: bool = True, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + index_query_options: Optional[QueryOptions] = None, + ) -> PostgresVectorStore: + """Create an PostgresVectorStore instance and validates the table schema. + + Args: + engine (PostgresEngine): Postgres Engine for managing connections to postgres database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str): Name of the database schema. Defaults to "public". + id_column (str): Column that represents if of a Node. Defaults to "node_id". + text_column (str): Column that represent text content of a Node. Defaults to "text". + embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". + metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". + node_column (str): Column that represents the whole JSON node. Defaults to "node_data". + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + index_query_options (QueryOptions): Index query option. + + Raises: + Exception: If table does not exist or follow the provided structure. + + Returns: + PostgresVectorStore + """ + coro = AsyncPostgresVectorStore.create( + engine, + table_name, + schema_name=schema_name, + id_column=id_column, + text_column=text_column, + embedding_column=embedding_column, + metadata_json_column=metadata_json_column, + metadata_columns=metadata_columns, + ref_doc_id_column=ref_doc_id_column, + node_column=node_column, + stores_text=stores_text, + is_embedding_query=is_embedding_query, + distance_strategy=distance_strategy, + index_query_options=index_query_options, + ) + vs = await engine._run_as_async(coro) + return cls( + cls.__create_key, + engine, + vs, + stores_text, + is_embedding_query, + ) + + @classmethod + def create_sync( + cls: Type[PostgresVectorStore], + engine: PostgresEngine, + table_name: str, + schema_name: str = "public", + id_column: str = "node_id", + text_column: str = "text", + embedding_column: str = "embedding", + metadata_json_column: str = "li_metadata", + metadata_columns: List[str] = [], + ref_doc_id_column: str = "ref_doc_id", + node_column: str = "node_data", + stores_text: bool = True, + is_embedding_query: bool = True, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + index_query_options: Optional[QueryOptions] = None, + ) -> PostgresVectorStore: + """Create an PostgresVectorStore instance and validates the table schema. + + Args: + engine (PostgresEngine): Postgres Engine for managing connections to postgres database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str): Name of the database schema. Defaults to "public". + id_column (str): Column that represents if of a Node. Defaults to "node_id". + text_column (str): Column that represent text content of a Node. Defaults to "text". + embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". + metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". + node_column (str): Column that represents the whole JSON node. Defaults to "node_data". + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". + distance_strategy (DistanceStrategy): Distance strategy to use for vector similarity search. Defaults to COSINE_DISTANCE. + index_query_options (QueryOptions): Index query option. + + Raises: + Exception: If table does not exist or follow the provided structure. + + Returns: + PostgresVectorStore + """ + coro = AsyncPostgresVectorStore.create( + engine, + table_name, + schema_name=schema_name, + id_column=id_column, + text_column=text_column, + embedding_column=embedding_column, + metadata_json_column=metadata_json_column, + metadata_columns=metadata_columns, + ref_doc_id_column=ref_doc_id_column, + node_column=node_column, + stores_text=stores_text, + is_embedding_query=is_embedding_query, + distance_strategy=distance_strategy, + index_query_options=index_query_options, + ) + vs = engine._run_as_sync(coro) + return cls( + cls.__create_key, + engine, + vs, + stores_text, + is_embedding_query, + ) + + @classmethod + def class_name(cls) -> str: + return "PostgresVectorStore" + + @property + def client(self) -> Any: + """Get client.""" + return self._engine + + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + """Asynchronously add nodes to the table.""" + return await self._engine._run_as_async(self.__vs.async_add(nodes, **kwargs)) + + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + """Synchronously add nodes to the table.""" + return self._engine._run_as_sync(self.__vs.async_add(nodes, **add_kwargs)) + + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Asynchronously delete nodes belonging to provided parent document from the table.""" + await self._engine._run_as_async(self.__vs.adelete(ref_doc_id, **delete_kwargs)) + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Synchronously delete nodes belonging to provided parent document from the table.""" + self._engine._run_as_sync(self.__vs.adelete(ref_doc_id, **delete_kwargs)) + + async def adelete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" + await self._engine._run_as_async( + self.__vs.adelete_nodes(node_ids, filters, **delete_kwargs) + ) + + def delete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + """Synchronously delete a set of nodes from the table matching the provided nodes and filters.""" + self._engine._run_as_sync( + self.__vs.adelete_nodes(node_ids, filters, **delete_kwargs) + ) + + async def aclear(self) -> None: + """Asynchronously delete all nodes from the table.""" + await self._engine._run_as_async(self.__vs.aclear()) + + def clear(self) -> None: + """Synchronously delete all nodes from the table.""" + return self._engine._run_as_sync(self.__vs.aclear()) + + async def aget_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[BaseNode]: + """Asynchronously get nodes from the table matching the provided nodes and filters.""" + return await self._engine._run_as_async(self.__vs.aget_nodes(node_ids, filters)) + + def get_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[BaseNode]: + """Asynchronously get nodes from the table matching the provided nodes and filters.""" + return self._engine._run_as_sync(self.__vs.aget_nodes(node_ids, filters)) + + async def aquery( + self, query: VectorStoreQuery, **kwargs: Any + ) -> VectorStoreQueryResult: + """Asynchronously query vector store.""" + return await self._engine._run_as_async(self.__vs.aquery(query, **kwargs)) + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + """Synchronously query vector store.""" + return self._engine._run_as_sync(self.__vs.aquery(query, **kwargs)) + + async def aset_maintenance_work_mem( + self, num_leaves: int, vector_size: int + ) -> None: + """Set database maintenance work memory (for index creation).""" + await self._engine._run_as_async( + self.__vs.set_maintenance_work_mem(num_leaves, vector_size) + ) + + def set_maintenance_work_mem(self, num_leaves: int, vector_size: int) -> None: + """Set database maintenance work memory (for index creation).""" + self._engine._run_as_sync( + self.__vs.set_maintenance_work_mem(num_leaves, vector_size) + ) + + async def aapply_vector_index( + self, + index: BaseIndex, + name: Optional[str] = None, + concurrently: bool = False, + ) -> None: + """Create an index on the vector store table.""" + return await self._engine._run_as_async( + self.__vs.aapply_vector_index(index, name, concurrently) + ) + + def apply_vector_index( + self, + index: BaseIndex, + name: Optional[str] = None, + concurrently: bool = False, + ) -> None: + """Create an index on the vector store table.""" + return self._engine._run_as_sync( + self.__vs.aapply_vector_index(index, name, concurrently) + ) + + async def areindex(self, index_name: Optional[str] = None) -> None: + """Re-index the vector store table.""" + return await self._engine._run_as_async(self.__vs.areindex(index_name)) + + def reindex(self, index_name: Optional[str] = None) -> None: + """Re-index the vector store table.""" + return self._engine._run_as_sync(self.__vs.areindex(index_name)) + + async def adrop_vector_index( + self, + index_name: Optional[str] = None, + ) -> None: + """Drop the vector index.""" + return await self._engine._run_as_async( + self.__vs.adrop_vector_index(index_name) + ) + + def drop_vector_index( + self, + index_name: Optional[str] = None, + ) -> None: + """Drop the vector index.""" + return self._engine._run_as_sync(self.__vs.adrop_vector_index(index_name)) + + async def ais_valid_index( + self, + index_name: Optional[str] = None, + ) -> bool: + """Check if index exists in the table.""" + return await self._engine._run_as_async(self.__vs.is_valid_index(index_name)) + + def is_valid_index( + self, + index_name: Optional[str] = None, + ) -> bool: + """Check if index exists in the table.""" + return self._engine._run_as_sync(self.__vs.is_valid_index(index_name)) diff --git a/tests/test_async_vector_store.py b/tests/test_async_vector_store.py new file mode 100644 index 0000000..45b8b6a --- /dev/null +++ b/tests/test_async_vector_store.py @@ -0,0 +1,348 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +from typing import List, Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.vector_stores.types import ( + FilterCondition, + FilterOperator, + MetadataFilter, + MetadataFilters, + VectorStoreQuery, +) +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping + +from llama_index_cloud_sql_pg import Column, PostgresEngine +from llama_index_cloud_sql_pg.async_vector_store import AsyncPostgresVectorStore + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) +DEFAULT_TABLE_CUSTOM_VS = "test_table" + str(uuid.uuid4()) +VECTOR_SIZE = 768 + +texts = ["foo", "bar", "baz", "foobar"] +embedding = [1.0] * VECTOR_SIZE +nodes = [ + TextNode( + id_=str(uuid.uuid4()), + text=texts[i], + embedding=[1 / (i + 1.0)] * VECTOR_SIZE, + ) + for i in range(len(texts)) +] +# setting each node as their own parent +for node in nodes: + node.relationships[NodeRelationship.SOURCE] = node.as_related_node_info() +sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead." + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + +@pytest.mark.asyncio(loop_scope="class") +class TestVectorStore: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on Cloud SQL instance") + + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): + engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield engine + await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE_CUSTOM_VS}"') + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + await engine._ainit_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) + vs = await AsyncPostgresVectorStore.create(engine, table_name=DEFAULT_TABLE) + yield vs + + @pytest_asyncio.fixture(scope="class") + async def custom_vs(self, engine): + await engine._ainit_vector_store_table( + DEFAULT_TABLE_CUSTOM_VS, + VECTOR_SIZE, + overwrite_existing=True, + metadata_columns=[ + Column(name="len", data_type="INTEGER", nullable=False), + Column( + name="nullable_int_field", + data_type="INTEGER", + nullable=True, + ), + Column( + name="nullable_str_field", + data_type="VARCHAR", + nullable=True, + ), + ], + ) + vs = await AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE_CUSTOM_VS, + metadata_columns=[ + "len", + "nullable_int_field", + "nullable_str_field", + ], + ) + yield vs + + async def test_init_with_constructor(self, engine): + with pytest.raises(Exception): + AsyncPostgresVectorStore(engine, table_name=DEFAULT_TABLE) + + async def test_validate_id_column_create(self, engine, vs): + test_id_column = "test_id_column" + with pytest.raises( + Exception, match=f"Id column, {test_id_column}, does not exist." + ): + await AsyncPostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, id_column=test_id_column + ) + + async def test_validate_text_column_create(self, engine, vs): + test_text_column = "test_text_column" + with pytest.raises( + Exception, match=f"Text column, {test_text_column}, does not exist." + ): + await AsyncPostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, text_column=test_text_column + ) + + async def test_validate_embedding_column_create(self, engine, vs): + test_embed_column = "test_embed_column" + with pytest.raises( + Exception, + match=f"Embedding column, {test_embed_column}, does not exist.", + ): + await AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + embedding_column=test_embed_column, + ) + + async def test_validate_node_column_create(self, engine, vs): + test_node_column = "test_node_column" + with pytest.raises( + Exception, match=f"Node column, {test_node_column}, does not exist." + ): + await AsyncPostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, node_column=test_node_column + ) + + async def test_validate_ref_doc_id_column_create(self, engine, vs): + test_ref_doc_id_column = "test_ref_doc_id_column" + with pytest.raises( + Exception, + match=f"Reference Document Id column, {test_ref_doc_id_column}, does not exist.", + ): + await AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + ref_doc_id_column=test_ref_doc_id_column, + ) + + async def test_validate_metadata_json_column_create(self, engine, vs): + test_metadata_json_column = "test_metadata_json_column" + with pytest.raises( + Exception, + match=f"Metadata column, {test_metadata_json_column}, does not exist.", + ): + await AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + metadata_json_column=test_metadata_json_column, + ) + + async def test_async_add(self, engine, vs): + await vs.async_add(nodes) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 4 + + async def test_async_add_custom_vs(self, engine, custom_vs): + # setting extra metadata to be indexed in separate column + for node in nodes: + node.metadata["len"] = len(node.text) + + await custom_vs.async_add(nodes) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE_CUSTOM_VS}"') + assert len(results) == 4 + assert results[0]["len"] == 3 + assert results[0]["nullable_int_field"] == None + assert results[0]["nullable_str_field"] == None + + async def test_adelete(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + await vs.adelete(nodes[0].node_id) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 3 + + async def test_adelete_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + await vs.adelete_nodes( + node_ids=[nodes[0].node_id, nodes[1].node_id], + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", + value="foo", + operator=FilterOperator.TEXT_MATCH, + ), + MetadataFilter(key="text", value="bar", operator=FilterOperator.EQ), + ], + condition=FilterCondition.OR, + ), + ) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 2 + + async def test_aget_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", + value="foo", + operator=FilterOperator.TEXT_MATCH, + ), + MetadataFilter( + key="text", + value="bar", + operator=FilterOperator.TEXT_MATCH, + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + + async def test_aquery(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = await vs.aquery(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_aclear(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_adelete + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + await vs.aclear() + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 0 + + async def test_add(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.add(nodes) + + async def test_get_nodes(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.get_nodes() + + async def test_query(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.query(VectorStoreQuery(query_str="foo")) + + async def test_delete(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.delete("test_ref_doc_id") + + async def test_delete_nodes(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.delete_nodes(["test_node_id"]) + + async def test_clear(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.clear() diff --git a/tests/test_async_vector_store_index.py b/tests/test_async_vector_store_index.py new file mode 100644 index 0000000..af86315 --- /dev/null +++ b/tests/test_async_vector_store_index.py @@ -0,0 +1,151 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import uuid +from typing import List, Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.async_vector_store import AsyncPostgresVectorStore +from llama_index_cloud_sql_pg.indexes import ( + DEFAULT_INDEX_NAME_SUFFIX, + DistanceStrategy, + HNSWIndex, + IVFFlatIndex, +) + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX +VECTOR_SIZE = 5 + +texts = ["foo", "bar", "baz", "foobar"] +ids = [str(uuid.uuid4()) for i in range(len(texts))] +metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] +embeddings = [[1.0] * VECTOR_SIZE for i in range(len(texts))] +nodes = [ + TextNode( + id_=ids[i], + text=texts[i], + embedding=embeddings[i], + metadata=metadatas[i], # type: ignore + ) + for i in range(len(texts)) +] + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + +@pytest.mark.asyncio(loop_scope="class") +class TestIndex: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on Cloud SQL instance") + + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): + engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + await engine._ainit_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) + vs = await AsyncPostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + ) + + await vs.async_add(nodes) + await vs.adrop_vector_index() + yield vs + + async def test_aapply_vector_index(self, vs): + index = HNSWIndex() + await vs.aapply_vector_index(index) + assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + + async def test_areindex(self, vs): + if not await vs.is_valid_index(DEFAULT_INDEX_NAME): + index = HNSWIndex() + await vs.aapply_vector_index(index) + await vs.areindex() + await vs.areindex(DEFAULT_INDEX_NAME) + assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + + async def test_dropindex(self, vs): + await vs.adrop_vector_index() + result = await vs.is_valid_index(DEFAULT_INDEX_NAME) + assert not result + + async def test_aapply_vector_index_ivfflat(self, vs): + index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + await vs.aapply_vector_index(index, concurrently=True) + assert await vs.is_valid_index(DEFAULT_INDEX_NAME) + index = IVFFlatIndex( + name="secondindex", + distance_strategy=DistanceStrategy.INNER_PRODUCT, + ) + await vs.aapply_vector_index(index) + assert await vs.is_valid_index("secondindex") + await vs.adrop_vector_index("secondindex") + + async def test_is_valid_index(self, vs): + is_valid = await vs.is_valid_index("invalid_index") + assert is_valid == False diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py new file mode 100644 index 0000000..b6939b0 --- /dev/null +++ b/tests/test_vector_store.py @@ -0,0 +1,829 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +from typing import List, Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from llama_index.core.vector_stores.types import ( + FilterCondition, + FilterOperator, + MetadataFilter, + MetadataFilters, + VectorStoreQuery, +) +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping + +from llama_index_cloud_sql_pg import PostgresEngine +from llama_index_cloud_sql_pg.vector_store import PostgresVectorStore + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) +VECTOR_SIZE = 5 + +texts = ["foo", "bar", "baz", "foobar", "foobarbaz"] +embedding = [1.0] * VECTOR_SIZE +nodes = [ + TextNode( + id_=str(uuid.uuid4()), + text=texts[i], + embedding=[1 / (i + 1.0)] * VECTOR_SIZE, + metadata={ # type: ignore + "votes": [str(j) for j in range(i + 1)], + "other_texts": texts[0:i], + }, + ) + for i in range(len(texts)) +] +# setting each node as their own parent +for node in nodes: + node.relationships[NodeRelationship.SOURCE] = node.as_related_node_info() + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +async def aexecute(engine: PostgresEngine, query: str) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: PostgresEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +@pytest.mark.asyncio(loop_scope="class") +class TestVectorStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on Cloud SQL instance") + + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): + sync_engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield sync_engine + + await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await sync_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + await engine.ainit_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) + vs = await PostgresVectorStore.create(engine, table_name=DEFAULT_TABLE) + yield vs + + async def test_init_with_constructor(self, engine): + with pytest.raises(Exception): + PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + + async def test_validate_id_column_create(self, engine, vs): + test_id_column = "test_id_column" + with pytest.raises( + Exception, match=f"Id column, {test_id_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, id_column=test_id_column + ) + + async def test_validate_text_column_create(self, engine, vs): + test_text_column = "test_text_column" + with pytest.raises( + Exception, match=f"Text column, {test_text_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, text_column=test_text_column + ) + + async def test_validate_embedding_column_create(self, engine, vs): + test_embed_column = "test_embed_column" + with pytest.raises( + Exception, match=f"Embedding column, {test_embed_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, embedding_column=test_embed_column + ) + + async def test_validate_node_column_create(self, engine, vs): + test_node_column = "test_node_column" + with pytest.raises( + Exception, match=f"Node column, {test_node_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, node_column=test_node_column + ) + + async def test_validate_ref_doc_id_column_create(self, engine, vs): + test_ref_doc_id_column = "test_ref_doc_id_column" + with pytest.raises( + Exception, + match=f"Reference Document Id column, {test_ref_doc_id_column}, does not exist.", + ): + await PostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + ref_doc_id_column=test_ref_doc_id_column, + ) + + async def test_validate_metadata_json_column_create(self, engine, vs): + test_metadata_json_column = "test_metadata_json_column" + with pytest.raises( + Exception, + match=f"Metadata column, {test_metadata_json_column}, does not exist.", + ): + await PostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + metadata_json_column=test_metadata_json_column, + ) + + async def test_add(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 5 + + async def test_async_add(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 5 + + async def test_delete(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + vs.delete(nodes[0].node_id) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 4 + + async def test_adelete(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + await vs.adelete(nodes[0].node_id) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 4 + + async def test_delete_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + vs.delete_nodes( + node_ids=[nodes[0].node_id, nodes[1].node_id], + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter(key="text", value="bar", operator=FilterOperator.EQ), + ], + condition=FilterCondition.OR, + ), + ) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 3 + + async def test_adelete_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + await vs.adelete_nodes( + node_ids=[nodes[0].node_id, nodes[1].node_id], + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter(key="text", value="bar", operator=FilterOperator.EQ), + ], + condition=FilterCondition.OR, + ), + ) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 3 + + async def test_get_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = vs.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter( + key="text", value="bar", operator=FilterOperator.TEXT_MATCH + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 2 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + + async def test_aget_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter( + key="text", value="bar", operator=FilterOperator.TEXT_MATCH + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 2 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + + async def test_aget_nodes_filter_1(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value=["foo", "fooz"], operator=FilterOperator.IN + ), + MetadataFilter( + key="text", value=["bar", "baarz"], operator=FilterOperator.NIN + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_get_nodes_filter_2(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = vs.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="other_texts", + value="", + operator=FilterOperator.IS_EMPTY, + ), + ], + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_aget_nodes_filter_3(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + nodes[0].excluded_embed_metadata_keys = ["abc", "def"] + vs.add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="other_texts", + operator=FilterOperator.CONTAINS, + value="foobar", + ), + ], + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobarbaz" + + async def test_get_nodes_filter_4(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = vs.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="votes", value=["3", "4"], operator=FilterOperator.ANY + ), + MetadataFilter( + key="votes", value=["3", "4"], operator=FilterOperator.ALL + ), + ], + condition=FilterCondition.OR, + ) + ) + + assert len(results) == 2 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + assert results[1].get_content(metadata_mode=MetadataMode.NONE) == "foobarbaz" + + async def test_aquery(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = await vs.aquery(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_query(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = vs.query(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_aclear(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_adelete + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + await vs.aclear() + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 0 + + async def test_clear(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_adelete + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + vs.clear() + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 0 + + +@pytest.mark.asyncio(loop_scope="class") +class TestVectorStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on Cloud SQL instance") + + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): + sync_engine = PostgresEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield sync_engine + + await aexecute(sync_engine, f'DROP TABLE "{DEFAULT_TABLE}"') + await sync_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + engine.init_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True + ) + vs = PostgresVectorStore.create_sync(engine, table_name=DEFAULT_TABLE) + yield vs + + async def test_init_with_constructor(self, engine): + with pytest.raises(Exception): + PostgresVectorStore(engine, table_name=DEFAULT_TABLE) + + async def test_validate_id_column_create(self, engine, vs): + test_id_column = "test_id_column" + with pytest.raises( + Exception, match=f"Id column, {test_id_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, id_column=test_id_column + ) + + async def test_validate_text_column_create(self, engine, vs): + test_text_column = "test_text_column" + with pytest.raises( + Exception, match=f"Text column, {test_text_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, text_column=test_text_column + ) + + async def test_validate_embedding_column_create(self, engine, vs): + test_embed_column = "test_embed_column" + with pytest.raises( + Exception, match=f"Embedding column, {test_embed_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, embedding_column=test_embed_column + ) + + async def test_validate_node_column_create(self, engine, vs): + test_node_column = "test_node_column" + with pytest.raises( + Exception, match=f"Node column, {test_node_column}, does not exist." + ): + await PostgresVectorStore.create( + engine, table_name=DEFAULT_TABLE, node_column=test_node_column + ) + + async def test_validate_ref_doc_id_column_create(self, engine, vs): + test_ref_doc_id_column = "test_ref_doc_id_column" + with pytest.raises( + Exception, + match=f"Reference Document Id column, {test_ref_doc_id_column}, does not exist.", + ): + await PostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + ref_doc_id_column=test_ref_doc_id_column, + ) + + async def test_validate_metadata_json_column_create(self, engine, vs): + test_metadata_json_column = "test_metadata_json_column" + with pytest.raises( + Exception, + match=f"Metadata column, {test_metadata_json_column}, does not exist.", + ): + await PostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + metadata_json_column=test_metadata_json_column, + ) + + async def test_add(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 5 + + async def test_async_add(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + await vs.async_add(nodes) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 5 + + async def test_delete(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + vs.delete(nodes[0].node_id) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 4 + + async def test_adelete(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + await vs.adelete(nodes[0].node_id) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 4 + + async def test_delete_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + vs.delete_nodes( + node_ids=[nodes[0].node_id, nodes[1].node_id], + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter(key="text", value="bar", operator=FilterOperator.EQ), + ], + condition=FilterCondition.OR, + ), + ) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 3 + + async def test_adelete_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + await vs.adelete_nodes( + node_ids=[nodes[0].node_id, nodes[1].node_id], + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter(key="text", value="bar", operator=FilterOperator.EQ), + ], + condition=FilterCondition.OR, + ), + ) + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 3 + + async def test_get_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = vs.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter( + key="text", value="bar", operator=FilterOperator.TEXT_MATCH + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 2 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + + async def test_aget_nodes(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value="foo", operator=FilterOperator.TEXT_MATCH + ), + MetadataFilter( + key="text", value="bar", operator=FilterOperator.TEXT_MATCH + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 2 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + + async def test_aget_nodes_filter_1(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="text", value=["foo", "fooz"], operator=FilterOperator.IN + ), + MetadataFilter( + key="text", value=["bar", "baarz"], operator=FilterOperator.NIN + ), + ], + condition=FilterCondition.AND, + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_get_nodes_filter_2(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = vs.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="other_texts", + value="", + operator=FilterOperator.IS_EMPTY, + ), + ], + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_aget_nodes_filter_3(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + nodes[0].excluded_embed_metadata_keys = ["abc", "def"] + vs.add(nodes) + results = await vs.aget_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="other_texts", + operator=FilterOperator.CONTAINS, + value="foobar", + ), + ], + ) + ) + + assert len(results) == 1 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobarbaz" + + async def test_get_nodes_filter_4(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + results = vs.get_nodes( + filters=MetadataFilters( + filters=[ + MetadataFilter( + key="votes", value=["3", "4"], operator=FilterOperator.ANY + ), + MetadataFilter( + key="votes", value=["3", "4"], operator=FilterOperator.ALL + ), + ], + condition=FilterCondition.OR, + ) + ) + + assert len(results) == 2 + assert results[0].get_content(metadata_mode=MetadataMode.NONE) == "foobar" + assert results[1].get_content(metadata_mode=MetadataMode.NONE) == "foobarbaz" + + async def test_aquery(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = await vs.aquery(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_query(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_async_add + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + query = VectorStoreQuery( + query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3 + ) + results = vs.query(query) + + assert results.nodes is not None + assert results.ids is not None + assert results.similarities is not None + assert len(results.nodes) == 3 + assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo" + + async def test_aclear(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_adelete + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + await vs.aclear() + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 0 + + async def test_clear(self, engine, vs): + # Note: To be migrated to a pytest dependency on test_adelete + # Blocked due to unexpected fixtures reloads while running integration test suite + await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') + vs.add(nodes) + vs.clear() + + results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') + assert len(results) == 0 diff --git a/tests/test_vector_store_index.py b/tests/test_vector_store_index.py new file mode 100644 index 0000000..4c39cec --- /dev/null +++ b/tests/test_vector_store_index.py @@ -0,0 +1,262 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import sys +import uuid +from typing import List, Sequence + +import pytest +import pytest_asyncio +from llama_index.core.schema import MetadataMode, NodeRelationship, TextNode +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping +from sqlalchemy.ext.asyncio import create_async_engine + +from llama_index_cloud_sql_pg import PostgresEngine, PostgresVectorStore +from llama_index_cloud_sql_pg.indexes import ( + DEFAULT_INDEX_NAME_SUFFIX, + DistanceStrategy, + HNSWIndex, + IVFFlatIndex, + IVFIndex, +) +from llama_index_cloud_sql_pg.vector_store import PostgresVectorStore + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_TABLE_ASYNC = "test_table" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +DEFAULT_INDEX_NAME = DEFAULT_TABLE + DEFAULT_INDEX_NAME_SUFFIX +DEFAULT_INDEX_NAME_ASYNC = DEFAULT_TABLE_ASYNC + DEFAULT_INDEX_NAME_SUFFIX +VECTOR_SIZE = 5 + + +texts = ["foo", "bar", "baz", "foobar", "foobaz"] +ids = [str(uuid.uuid4()) for i in range(len(texts))] +metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] +embeddings = [[1.0] * VECTOR_SIZE for i in range(len(texts))] +nodes = [ + TextNode( + id_=ids[i], + text=texts[i], + embedding=embeddings[i], + metadata=metadatas[i], # type: ignore + ) + for i in range(len(texts)) +] + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +async def aexecute( + engine: PostgresEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +@pytest.mark.asyncio(loop_scope="class") +class TestIndexSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on Cloud SQL instance") + + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): + engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + engine.init_vector_store_table(DEFAULT_TABLE, VECTOR_SIZE) + vs = PostgresVectorStore.create_sync( + engine, + table_name=DEFAULT_TABLE, + ) + + await vs.async_add(nodes) + + vs.drop_vector_index() + yield vs + + async def test_aapply_vector_index(self, vs): + index = HNSWIndex() + vs.apply_vector_index(index) + assert vs.is_valid_index(DEFAULT_INDEX_NAME) + + async def test_areindex(self, vs): + if not vs.is_valid_index(DEFAULT_INDEX_NAME): + index = HNSWIndex() + vs.apply_vector_index(index) + vs.reindex() + vs.reindex(DEFAULT_INDEX_NAME) + assert vs.is_valid_index(DEFAULT_INDEX_NAME) + + async def test_dropindex(self, vs): + vs.drop_vector_index() + result = vs.is_valid_index(DEFAULT_INDEX_NAME) + assert not result + + async def test_aapply_vector_index_ivfflat(self, vs): + index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + vs.apply_vector_index(index, concurrently=True) + assert vs.is_valid_index(DEFAULT_INDEX_NAME) + index = IVFFlatIndex( + name="secondindex", + distance_strategy=DistanceStrategy.INNER_PRODUCT, + ) + vs.apply_vector_index(index) + assert vs.is_valid_index("secondindex") + vs.drop_vector_index("secondindex") + + async def test_is_valid_index(self, vs): + is_valid = vs.is_valid_index("invalid_index") + assert is_valid == False + + +@pytest.mark.asyncio(loop_scope="class") +class TestAsyncIndex: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "instance for Cloud SQL") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on Cloud SQL instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on Cloud SQL instance") + + @pytest_asyncio.fixture(scope="class") + async def engine(self, db_project, db_region, db_instance, db_name): + engine = await PostgresEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + await engine.ainit_vector_store_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) + vs = await PostgresVectorStore.create( + engine, + table_name=DEFAULT_TABLE_ASYNC, + ) + + await vs.async_add(nodes) + await vs.adrop_vector_index() + yield vs + + async def test_aapply_vector_index(self, vs): + index = HNSWIndex() + await vs.aapply_vector_index(index) + assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + + async def test_areindex(self, vs): + if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): + index = HNSWIndex() + await vs.aapply_vector_index(index) + await vs.areindex() + await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) + assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + + async def test_dropindex(self, vs): + await vs.adrop_vector_index() + result = await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + assert not result + + async def test_aapply_vector_index_ivfflat(self, vs): + index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + await vs.aapply_vector_index(index, concurrently=True) + assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + index = IVFFlatIndex( + name="secondindex", + distance_strategy=DistanceStrategy.INNER_PRODUCT, + ) + await vs.aapply_vector_index(index) + assert await vs.ais_valid_index("secondindex") + await vs.adrop_vector_index("secondindex") + await vs.adrop_vector_index() + + async def test_is_valid_index(self, vs): + is_valid = await vs.ais_valid_index("invalid_index") + assert is_valid == False + + async def test_aapply_vector_index_ivf(self, vs): + index = IVFIndex(distance_strategy=DistanceStrategy.EUCLIDEAN) + await vs.aapply_vector_index(index, concurrently=True) + assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) + index = IVFIndex( + name="secondindex", + distance_strategy=DistanceStrategy.INNER_PRODUCT, + ) + await vs.aapply_vector_index(index) + assert await vs.ais_valid_index("secondindex") + await vs.adrop_vector_index("secondindex") + await vs.adrop_vector_index()