
Commit 85341f2

Updated to use the 0.9.0 SDK (#936)
1 parent bd166e8 commit 85341f2

3 files changed: +33 -156 lines changed

pgml-apps/pgml-chat/README.md

Lines changed: 3 additions & 17 deletions
@@ -71,22 +71,8 @@ In this step, we ingest documents, chunk documents, generate embeddings and index
 LOG_LEVEL=DEBUG pgml-chat --root_dir <directory> --collection_name <collection_name> --stage ingest
 ```
 
-You will see the following output:
-```bash
-[15:39:12] DEBUG [15:39:12] - Using selector: KqueueSelector
-INFO [15:39:12] - Starting pgml_chatbot
-INFO [15:39:12] - Scanning <root directory> for markdown files
-[15:39:13] INFO [15:39:13] - Found 85 markdown files
-Extracting text from markdown ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
-INFO [15:39:13] - Upserting documents into database
-[15:39:32] INFO [15:39:32] - Generating chunks
-[15:39:33] INFO [15:39:33] - Starting chunk count: 0
-[15:39:35] INFO [15:39:35] - Ending chunk count: 576
-INFO [15:39:35] - Total documents: 85 Total chunks: 576
-INFO [15:39:35] - Generating embeddings
-[15:39:36] INFO [15:39:36] - Splitter ID: 2
-[15:40:47] INFO [15:40:47] - Embeddings generated in 71.073 seconds
-```
+You will see output logging the pipeline's progress.
+
 ## Chat
 You can interact with the bot using the command line interface or Slack.
 

@@ -185,4 +171,4 @@ You can control the behavior of the chatbot by setting the following environment
 - ~~`hyerbot --chat_interface {cli, slack, discord}` that supports Slack, and Discord.~~
 - Support for file formats like rst, html, pdf, docx, etc.
 - Support for open source models in addition to OpenAI for chat completion.
-- Support for multi-turn conversations using a conversation buffer. Use a collection for chat history that can be retrieved and used to generate responses.
+- Support for multi-turn conversations using a conversation buffer. Use a collection for chat history that can be retrieved and used to generate responses.

pgml-apps/pgml-chat/pgml_chat/main.py

Lines changed: 29 additions & 138 deletions
@@ -1,5 +1,5 @@
 import asyncio
-from pgml import Database
+from pgml import Collection, Model, Splitter, Pipeline
 import logging
 from rich.logging import RichHandler
 from rich.progress import track
@@ -77,26 +77,28 @@ def handler(signum, frame):
 
 
 # The code is using the `argparse` module to parse command line arguments.
-collection_name = args.collection_name
+collection = Collection(args.collection_name)
 stage = args.stage
 chat_interface = args.chat_interface
 
 # The above code is retrieving environment variables and assigning their values to various variables.
 database_url = os.environ.get("DATABASE_URL")
-db = Database(database_url)
-splitter = os.environ.get("SPLITTER", "recursive_character")
+splitter_name = os.environ.get("SPLITTER", "recursive_character")
 splitter_params = os.environ.get(
     "SPLITTER_PARAMS", {"chunk_size": 1500, "chunk_overlap": 40}
 )
-model = os.environ.get("MODEL", "intfloat/e5-small")
+splitter = Splitter(splitter_name, splitter_params)
+model_name = os.environ.get("MODEL", "intfloat/e5-small")
 model_params = ast.literal_eval(os.environ.get("MODEL_PARAMS", {}))
+model = Model(model_name, "pgml", model_params)
+pipeline = Pipeline(args.collection_name + "_pipeline", model, splitter)
 query_params = ast.literal_eval(os.environ.get("QUERY_PARAMS", {}))
 system_prompt = os.environ.get("SYSTEM_PROMPT")
 base_prompt = os.environ.get("BASE_PROMPT")
 openai_api_key = os.environ.get("OPENAI_API_KEY")
 
 
-async def upsert_documents(db: Database, collection_name: str, folder: str) -> int:
+async def upsert_documents(folder: str) -> int:
     log.info("Scanning " + folder + " for markdown files")
     md_files = []
     # root_dir needs a trailing slash (i.e. /root/dir/)
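
The hunk above is the heart of the migration: the 0.9.0 SDK drops the single `Database` handle in favor of composable `Collection`, `Splitter`, `Model`, and `Pipeline` objects. A minimal sketch of the new setup, assuming `DATABASE_URL` is exported so the SDK can reach Postgres (the names below are illustrative, not taken from the commit):

```python
from pgml import Collection, Model, Splitter, Pipeline

# Illustrative names; pgml-chat derives them from CLI args and env vars.
collection = Collection("pgml_chat_docs")
splitter = Splitter("recursive_character", {"chunk_size": 1500, "chunk_overlap": 40})
model = Model("intfloat/e5-small", "pgml", {})  # same constructor shape as the diff above
pipeline = Pipeline("pgml_chat_docs_pipeline", model, splitter)
```

Note how the pipeline name is derived from the collection name, keeping one pipeline per collection.
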
@@ -107,100 +109,14 @@ async def upsert_documents(db: Database, collection_name: str, folder: str) -> int:
     documents = []
     for md_file in track(md_files, description="Extracting text from markdown"):
         with open(md_file, "r") as f:
-            documents.append({"text": f.read(), "filename": md_file})
+            documents.append({"text": f.read(), "id": md_file})
 
     log.info("Upserting documents into database")
-    collection = await db.create_or_get_collection(collection_name)
     await collection.upsert_documents(documents)
 
     return len(md_files)
 
 
-async def generate_chunks(
-    db: Database,
-    collection_name: str,
-    splitter: str = "recursive_character",
-    splitter_params: dict = {"chunk_size": 1500, "chunk_overlap": 40},
-) -> int:
-    """
-    The function `generate_chunks` generates chunks for a given collection in a database and returns the
-    count of chunks created.
-
-    :param db: The `db` parameter is an instance of a database connection or client. It is used to
-    interact with the database and perform operations such as creating collections, executing queries,
-    and fetching results
-    :type db: Database
-    :param collection_name: The `collection_name` parameter is a string that represents the name of the
-    collection in the database. It is used to create or get the collection and perform operations on it
-    :type collection_name: str
-    :return: The function `generate_chunks` returns an integer, which represents the count of chunks
-    generated in the specified collection.
-    """
-    log.info("Generating chunks")
-    collection = await db.create_or_get_collection(collection_name)
-    splitter_id = await collection.register_text_splitter(splitter, splitter_params)
-    query_string = """SELECT count(*) from {collection_name}.chunks""".format(
-        collection_name=collection_name
-    )
-    results = await db.query(query_string).fetch_all()
-    start_chunks = results[0]["count"]
-    log.info("Starting chunk count: " + str(start_chunks))
-    await collection.generate_chunks(splitter_id)
-    results = await db.query(query_string).fetch_all()
-    log.info("Ending chunk count: " + str(results[0]["count"]))
-    return results[0]["count"] - start_chunks
-
-
-async def generate_embeddings(
-    db: Database,
-    collection_name: str,
-    splitter: str = "recursive_character",
-    splitter_params: dict = {"chunk_size": 1500, "chunk_overlap": 40},
-    model: str = "intfloat/e5-small",
-    model_params: dict = {},
-) -> int:
-    """
-    The `generate_embeddings` function generates embeddings for text data using a specified model and
-    splitter.
-
-    :param db: The `db` parameter is an instance of a database object. It is used to interact with the
-    database and perform operations such as creating or getting a collection, registering a text
-    splitter, registering a model, and generating embeddings
-    :type db: Database
-    :param collection_name: The `collection_name` parameter is a string that represents the name of the
-    collection in the database where the embeddings will be generated
-    :type collection_name: str
-    :param splitter: The `splitter` parameter is used to specify the text splitting method to be used
-    during the embedding generation process. In this case, the value is set to "recursive_character",
-    which suggests that the text will be split into chunks based on recursive character splitting,
-    defaults to recursive_character
-    :type splitter: str (optional)
-    :param splitter_params: The `splitter_params` parameter is a dictionary that contains the parameters
-    for the text splitter. In this case, the `splitter_params` dictionary has two keys:
-    :type splitter_params: dict
-    :param model: The `model` parameter is the name or identifier of the language model that will be
-    used to generate the embeddings. In this case, the model is specified as "intfloat/e5-small",
-    defaults to intfloat/e5-small
-    :type model: str (optional)
-    :param model_params: The `model_params` parameter is a dictionary that allows you to specify
-    additional parameters for the model. These parameters can be used to customize the behavior of the
-    model during the embedding generation process. The specific parameters that can be included in the
-    `model_params` dictionary will depend on the specific model you are
-    :type model_params: dict
-    :return: an integer value of 0.
-    """
-    log.info("Generating embeddings")
-    collection = await db.create_or_get_collection(collection_name)
-    splitter_id = await collection.register_text_splitter(splitter, splitter_params)
-    model_id = await collection.register_model("embedding", model, model_params)
-    log.info("Splitter ID: " + str(splitter_id))
-    start = time()
-    await collection.generate_embeddings(model_id, splitter_id)
-    log.info("Embeddings generated in %0.3f seconds" % (time() - start))
-
-    return 0
-
-
 async def generate_response(
     messages, openai_api_key, temperature=0.7, max_tokens=256, top_p=0.9
 ):
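
Everything deleted in this hunk, the `register_text_splitter`/`register_model` bookkeeping and the explicit `generate_chunks`/`generate_embeddings` passes, becomes implicit in 0.9.0: per the new `ingest_documents` later in this diff, chunking and embedding happen as part of upserting once a pipeline is attached. A minimal sketch, reusing the objects from the previous snippet:

```python
import asyncio

async def ingest():
    # Idempotent: the diff notes this does nothing if the pipeline
    # has already been added to the collection.
    await collection.add_pipeline(pipeline)
    # Upserting now chunks and embeds automatically; note the "id" key,
    # which replaces the old "filename" key.
    await collection.upsert_documents(
        [{"id": "README.md", "text": "PostgresML is an ML extension for Postgres."}]
    )

asyncio.run(ingest())
```
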
@@ -217,44 +133,20 @@ async def generate_response(
     return response["choices"][0]["message"]["content"]
 
 
-async def ingest_documents(
-    db: Database,
-    collection_name: str,
-    folder: str,
-    splitter: str,
-    splitter_params: dict,
-    model: str,
-    model_params: dict,
-):
-    total_docs = await upsert_documents(db, collection_name, folder=folder)
-    total_chunks = await generate_chunks(
-        db, collection_name, splitter=splitter, splitter_params=splitter_params
-    )
-    log.info(
-        "Total documents: " + str(total_docs) + " Total chunks: " + str(total_chunks)
-    )
-
-    await generate_embeddings(
-        db,
-        collection_name,
-        splitter=splitter,
-        splitter_params=splitter_params,
-        model=model,
-        model_params=model_params,
-    )
+async def ingest_documents(folder: str):
+    # Add the pipeline to the collection, does nothing if we have already added it
+    await collection.add_pipeline(pipeline)
+    # This will upsert, chunk, and embed the contents in the folder
+    total_docs = await upsert_documents(folder)
+    log.info("Total documents: " + str(total_docs))
 
 
 async def get_prompt(user_input: str = ""):
-    collection = await db.create_or_get_collection(collection_name)
-    model_id = await collection.register_model("embedding", model, model_params)
-    splitter_id = await collection.register_text_splitter(splitter, splitter_params)
-    log.info("Model id: " + str(model_id) + " Splitter id: " + str(splitter_id))
-    vector_results = await collection.vector_search(
-        user_input,
-        model_id=model_id,
-        splitter_id=splitter_id,
-        top_k=2,
-        query_params=query_params,
+    vector_results = (
+        await collection.query()
+        .vector_recall(user_input, pipeline, query_params)
+        .limit(2)
+        .fetch_all()
     )
     log.info(vector_results)
     context = ""
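
Retrieval changes in the same spirit: the old `vector_search` call with registered `model_id`/`splitter_id` becomes a chained query builder, with `.limit(2)` taking over the role of `top_k=2`. A minimal sketch, again reusing the objects from the first snippet (an empty dict stands in for the script's `QUERY_PARAMS`):

```python
async def retrieve(user_input: str):
    # Builder chain from the new get_prompt: recall vectors through the
    # pipeline, keep the two best matches, then fetch them.
    return (
        await collection.query()
        .vector_recall(user_input, pipeline, {})
        .limit(2)
        .fetch_all()
    )
```
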
@@ -322,10 +214,12 @@ async def message_hello(message, say):
     intents.message_content = True
     client = discord.Client(intents=intents)
 
+
     @client.event
     async def on_ready():
         print(f"We have logged in as {client.user}")
 
+
     @client.event
     async def on_message(message):
         bot_mention = f"<@{client.user.id}>"
@@ -351,15 +245,7 @@ async def run():
 
     if stage == "ingest":
         root_dir = args.root_dir
-        await ingest_documents(
-            db,
-            collection_name,
-            root_dir,
-            splitter,
-            splitter_params,
-            model,
-            model_params,
-        )
+        await ingest_documents(root_dir)
 
     elif stage == "chat":
         if chat_interface == "cli":
@@ -369,7 +255,12 @@ async def run():
 
 
 def main():
-    if stage == "chat" and chat_interface == "discord" and os.environ.get("DISCORD_BOT_TOKEN"):
+    if (
+        stage == "chat"
+        and chat_interface == "discord"
+        and os.environ.get("DISCORD_BOT_TOKEN")
+    ):
         client.run(os.environ["DISCORD_BOT_TOKEN"])
     else:
         asyncio.run(run())
+main()

pgml-apps/pgml-chat/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ packages = [{include = "pgml_chat"}]
 python = ">=3.8,<4.0"
 openai = "^0.27.8"
 rich = "^13.4.2"
-pgml = "^0.8.0"
+pgml = "^0.9.0"
 python-dotenv = "^1.0.0"
 click = "^8.1.6"
 black = "^23.7.0"
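
To confirm the dependency bump took effect after reinstalling, a quick check using only the standard library (assumes a Poetry-managed environment, nothing pgml-specific):

```python
from importlib.metadata import version

# Should print a 0.9.x release once the new lockfile is installed.
print(version("pgml"))
```
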
