From 8d512ab5741ecb4ae061ad2f3b42f8200f9cdc73 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 21 Aug 2023 14:09:51 -0700 Subject: [PATCH 1/2] Updated to use the 0.9.0 SDK --- pgml-apps/pgml-chat/pgml_chat/main.py | 167 +++++--------------------- pgml-apps/pgml-chat/pyproject.toml | 2 +- 2 files changed, 30 insertions(+), 139 deletions(-) diff --git a/pgml-apps/pgml-chat/pgml_chat/main.py b/pgml-apps/pgml-chat/pgml_chat/main.py index a43d0653d..51bf7d8d1 100644 --- a/pgml-apps/pgml-chat/pgml_chat/main.py +++ b/pgml-apps/pgml-chat/pgml_chat/main.py @@ -1,5 +1,5 @@ import asyncio -from pgml import Database +from pgml import Collection, Model, Splitter, Pipeline import logging from rich.logging import RichHandler from rich.progress import track @@ -77,26 +77,28 @@ def handler(signum, frame): # The code is using the `argparse` module to parse command line arguments. -collection_name = args.collection_name +collection = Collection(args.collection_name) stage = args.stage chat_interface = args.chat_interface # The above code is retrieving environment variables and assigning their values to various variables. database_url = os.environ.get("DATABASE_URL") -db = Database(database_url) -splitter = os.environ.get("SPLITTER", "recursive_character") +splitter_name = os.environ.get("SPLITTER", "recursive_character") splitter_params = os.environ.get( "SPLITTER_PARAMS", {"chunk_size": 1500, "chunk_overlap": 40} ) -model = os.environ.get("MODEL", "intfloat/e5-small") +splitter = Splitter(splitter_name, splitter_params) +model_name = os.environ.get("MODEL", "intfloat/e5-small") model_params = ast.literal_eval(os.environ.get("MODEL_PARAMS", {})) +model = Model(model_name, "pgml", model_params) +pipeline = Pipeline(args.collection_name + "_pipeline", model, splitter) query_params = ast.literal_eval(os.environ.get("QUERY_PARAMS", {})) system_prompt = os.environ.get("SYSTEM_PROMPT") base_prompt = os.environ.get("BASE_PROMPT") openai_api_key = os.environ.get("OPENAI_API_KEY") -async def upsert_documents(db: Database, collection_name: str, folder: str) -> int: +async def upsert_documents(folder: str) -> int: log.info("Scanning " + folder + " for markdown files") md_files = [] # root_dir needs a trailing slash (i.e. /root/dir/) @@ -107,100 +109,14 @@ async def upsert_documents(db: Database, collection_name: str, folder: str) -> i documents = [] for md_file in track(md_files, description="Extracting text from markdown"): with open(md_file, "r") as f: - documents.append({"text": f.read(), "filename": md_file}) + documents.append({"text": f.read(), "id": md_file}) log.info("Upserting documents into database") - collection = await db.create_or_get_collection(collection_name) await collection.upsert_documents(documents) return len(md_files) -async def generate_chunks( - db: Database, - collection_name: str, - splitter: str = "recursive_character", - splitter_params: dict = {"chunk_size": 1500, "chunk_overlap": 40}, -) -> int: - """ - The function `generate_chunks` generates chunks for a given collection in a database and returns the - count of chunks created. - - :param db: The `db` parameter is an instance of a database connection or client. It is used to - interact with the database and perform operations such as creating collections, executing queries, - and fetching results - :type db: Database - :param collection_name: The `collection_name` parameter is a string that represents the name of the - collection in the database. 
It is used to create or get the collection and perform operations on it - :type collection_name: str - :return: The function `generate_chunks` returns an integer, which represents the count of chunks - generated in the specified collection. - """ - log.info("Generating chunks") - collection = await db.create_or_get_collection(collection_name) - splitter_id = await collection.register_text_splitter(splitter, splitter_params) - query_string = """SELECT count(*) from {collection_name}.chunks""".format( - collection_name=collection_name - ) - results = await db.query(query_string).fetch_all() - start_chunks = results[0]["count"] - log.info("Starting chunk count: " + str(start_chunks)) - await collection.generate_chunks(splitter_id) - results = await db.query(query_string).fetch_all() - log.info("Ending chunk count: " + str(results[0]["count"])) - return results[0]["count"] - start_chunks - - -async def generate_embeddings( - db: Database, - collection_name: str, - splitter: str = "recursive_character", - splitter_params: dict = {"chunk_size": 1500, "chunk_overlap": 40}, - model: str = "intfloat/e5-small", - model_params: dict = {}, -) -> int: - """ - The `generate_embeddings` function generates embeddings for text data using a specified model and - splitter. - - :param db: The `db` parameter is an instance of a database object. It is used to interact with the - database and perform operations such as creating or getting a collection, registering a text - splitter, registering a model, and generating embeddings - :type db: Database - :param collection_name: The `collection_name` parameter is a string that represents the name of the - collection in the database where the embeddings will be generated - :type collection_name: str - :param splitter: The `splitter` parameter is used to specify the text splitting method to be used - during the embedding generation process. In this case, the value is set to "recursive_character", - which suggests that the text will be split into chunks based on recursive character splitting, - defaults to recursive_character - :type splitter: str (optional) - :param splitter_params: The `splitter_params` parameter is a dictionary that contains the parameters - for the text splitter. In this case, the `splitter_params` dictionary has two keys: - :type splitter_params: dict - :param model: The `model` parameter is the name or identifier of the language model that will be - used to generate the embeddings. In this case, the model is specified as "intfloat/e5-small", - defaults to intfloat/e5-small - :type model: str (optional) - :param model_params: The `model_params` parameter is a dictionary that allows you to specify - additional parameters for the model. These parameters can be used to customize the behavior of the - model during the embedding generation process. The specific parameters that can be included in the - `model_params` dictionary will depend on the specific model you are - :type model_params: dict - :return: an integer value of 0. 
- """ - log.info("Generating embeddings") - collection = await db.create_or_get_collection(collection_name) - splitter_id = await collection.register_text_splitter(splitter, splitter_params) - model_id = await collection.register_model("embedding", model, model_params) - log.info("Splitter ID: " + str(splitter_id)) - start = time() - await collection.generate_embeddings(model_id, splitter_id) - log.info("Embeddings generated in %0.3f seconds" % (time() - start)) - - return 0 - - async def generate_response( messages, openai_api_key, temperature=0.7, max_tokens=256, top_p=0.9 ): @@ -217,44 +133,20 @@ async def generate_response( return response["choices"][0]["message"]["content"] -async def ingest_documents( - db: Database, - collection_name: str, - folder: str, - splitter: str, - splitter_params: dict, - model: str, - model_params: dict, -): - total_docs = await upsert_documents(db, collection_name, folder=folder) - total_chunks = await generate_chunks( - db, collection_name, splitter=splitter, splitter_params=splitter_params - ) - log.info( - "Total documents: " + str(total_docs) + " Total chunks: " + str(total_chunks) - ) - - await generate_embeddings( - db, - collection_name, - splitter=splitter, - splitter_params=splitter_params, - model=model, - model_params=model_params, - ) +async def ingest_documents(folder: str): + # Add the pipeline to the collection, does nothing if we have already added it + await collection.add_pipeline(pipeline) + # This will upsert, chunk, and embed the contents in the folder + total_docs = await upsert_documents(folder) + log.info("Total documents: " + str(total_docs)) async def get_prompt(user_input: str = ""): - collection = await db.create_or_get_collection(collection_name) - model_id = await collection.register_model("embedding", model, model_params) - splitter_id = await collection.register_text_splitter(splitter, splitter_params) - log.info("Model id: " + str(model_id) + " Splitter id: " + str(splitter_id)) - vector_results = await collection.vector_search( - user_input, - model_id=model_id, - splitter_id=splitter_id, - top_k=2, - query_params=query_params, + vector_results = ( + await collection.query() + .vector_recall(user_input, pipeline, query_params) + .limit(2) + .fetch_all() ) log.info(vector_results) context = "" @@ -322,10 +214,12 @@ async def message_hello(message, say): intents.message_content = True client = discord.Client(intents=intents) + @client.event async def on_ready(): print(f"We have logged in as {client.user}") + @client.event async def on_message(message): bot_mention = f"<@{client.user.id}>" @@ -351,15 +245,7 @@ async def run(): if stage == "ingest": root_dir = args.root_dir - await ingest_documents( - db, - collection_name, - root_dir, - splitter, - splitter_params, - model, - model_params, - ) + await ingest_documents(root_dir) elif stage == "chat": if chat_interface == "cli": @@ -369,7 +255,12 @@ async def run(): def main(): - if stage == "chat" and chat_interface == "discord" and os.environ.get("DISCORD_BOT_TOKEN"): + if ( + stage == "chat" + and chat_interface == "discord" + and os.environ.get("DISCORD_BOT_TOKEN") + ): client.run(os.environ["DISCORD_BOT_TOKEN"]) else: asyncio.run(run()) +main() diff --git a/pgml-apps/pgml-chat/pyproject.toml b/pgml-apps/pgml-chat/pyproject.toml index ae3254f3e..10f9c95e9 100644 --- a/pgml-apps/pgml-chat/pyproject.toml +++ b/pgml-apps/pgml-chat/pyproject.toml @@ -11,7 +11,7 @@ packages = [{include = "pgml_chat"}] python = ">=3.8,<4.0" openai = "^0.27.8" rich = "^13.4.2" -pgml = "^0.8.0" 
+pgml = "^0.9.0" python-dotenv = "^1.0.0" click = "^8.1.6" black = "^23.7.0" From abe98fc916e6cd763f4a5be833d5648cd2efb52a Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 21 Aug 2023 14:25:09 -0700 Subject: [PATCH 2/2] Updated a tiny bit of text --- pgml-apps/pgml-chat/README.md | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/pgml-apps/pgml-chat/README.md b/pgml-apps/pgml-chat/README.md index 5f4b13a68..56fcf5783 100644 --- a/pgml-apps/pgml-chat/README.md +++ b/pgml-apps/pgml-chat/README.md @@ -71,22 +71,8 @@ In this step, we ingest documents, chunk documents, generate embeddings and inde LOG_LEVEL=DEBUG pgml-chat --root_dir --collection_name --stage ingest ``` -You will see the following output: -```bash -[15:39:12] DEBUG [15:39:12] - Using selector: KqueueSelector - INFO [15:39:12] - Starting pgml_chatbot - INFO [15:39:12] - Scanning for markdown files -[15:39:13] INFO [15:39:13] - Found 85 markdown files -Extracting text from markdown ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 - INFO [15:39:13] - Upserting documents into database -[15:39:32] INFO [15:39:32] - Generating chunks -[15:39:33] INFO [15:39:33] - Starting chunk count: 0 -[15:39:35] INFO [15:39:35] - Ending chunk count: 576 - INFO [15:39:35] - Total documents: 85 Total chunks: 576 - INFO [15:39:35] - Generating embeddings -[15:39:36] INFO [15:39:36] - Splitter ID: 2 -[15:40:47] INFO [15:40:47] - Embeddings generated in 71.073 seconds -``` +You will see output logging the pipelines progress. + ## Chat You can interact with the bot using the command line interface or Slack. @@ -185,4 +171,4 @@ You can control the behavior of the chatbot by setting the following environment - ~~`hyerbot --chat_interface {cli, slack, discord}` that supports Slack, and Discord.~~ - Support for file formats like rst, html, pdf, docx, etc. - Support for open source models in addition to OpenAI for chat completion. -- Support for multi-turn converstaions using converstaion buffer. Use a collection for chat history that can be retrieved and used to generate responses. \ No newline at end of file +- Support for multi-turn converstaions using converstaion buffer. Use a collection for chat history that can be retrieved and used to generate responses.