20 changes: 3 additions & 17 deletions pgml-apps/pgml-chat/README.md
@@ -71,22 +71,8 @@ In this step, we ingest documents, chunk documents, generate embeddings and index
LOG_LEVEL=DEBUG pgml-chat --root_dir <directory> --collection_name <collection_name> --stage ingest
```

You will see the following output:
```bash
[15:39:12] DEBUG [15:39:12] - Using selector: KqueueSelector
INFO [15:39:12] - Starting pgml_chatbot
INFO [15:39:12] - Scanning <root directory> for markdown files
[15:39:13] INFO [15:39:13] - Found 85 markdown files
Extracting text from markdown ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
INFO [15:39:13] - Upserting documents into database
[15:39:32] INFO [15:39:32] - Generating chunks
[15:39:33] INFO [15:39:33] - Starting chunk count: 0
[15:39:35] INFO [15:39:35] - Ending chunk count: 576
INFO [15:39:35] - Total documents: 85 Total chunks: 576
INFO [15:39:35] - Generating embeddings
[15:39:36] INFO [15:39:36] - Splitter ID: 2
[15:40:47] INFO [15:40:47] - Embeddings generated in 71.073 seconds
```
You will see output logging the pipeline's progress.

## Chat
You can interact with the bot using the command line interface or Slack.
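For example, a CLI chat session can be started with the flags defined in `main.py` (substitute the collection name you used during ingestion):
```bash
pgml-chat --collection_name <collection_name> --stage chat --chat_interface cli
```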

@@ -185,4 +171,4 @@ You can control the behavior of the chatbot by setting the following environment
- ~~`hyerbot --chat_interface {cli, slack, discord}` that supports Slack and Discord.~~
- Support for file formats like rst, html, pdf, docx, etc.
- Support for open source models in addition to OpenAI for chat completion.
- Support for multi-turn conversations using a conversation buffer. Use a collection for chat history that can be retrieved and used to generate responses.
167 changes: 29 additions & 138 deletions pgml-apps/pgml-chat/pgml_chat/main.py
@@ -1,5 +1,5 @@
import asyncio
from pgml import Database
from pgml import Collection, Model, Splitter, Pipeline
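# pgml 0.9 replaces the Database-centric API with Collection, Model, Splitter, and Pipeline objects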
import logging
from rich.logging import RichHandler
from rich.progress import track
@@ -77,26 +77,28 @@ def handler(signum, frame):


# The code is using the `argparse` module to parse command line arguments.
collection_name = args.collection_name
collection = Collection(args.collection_name)
stage = args.stage
chat_interface = args.chat_interface

# The above code is retrieving environment variables and assigning their values to various variables.
database_url = os.environ.get("DATABASE_URL")
db = Database(database_url)
splitter = os.environ.get("SPLITTER", "recursive_character")
splitter_name = os.environ.get("SPLITTER", "recursive_character")
# literal_eval needs a string, so the default is given as a string literal
splitter_params = ast.literal_eval(
os.environ.get("SPLITTER_PARAMS", '{"chunk_size": 1500, "chunk_overlap": 40}')
)
model = os.environ.get("MODEL", "intfloat/e5-small")
splitter = Splitter(splitter_name, splitter_params)
model_name = os.environ.get("MODEL", "intfloat/e5-small")
model_params = ast.literal_eval(os.environ.get("MODEL_PARAMS", "{}"))  # literal_eval needs a string default
model = Model(model_name, "pgml", model_params)
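# The pipeline bundles the splitter and embedding model; once added to the collection, upserted documents are chunked and embedded automatically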
pipeline = Pipeline(args.collection_name + "_pipeline", model, splitter)
query_params = ast.literal_eval(os.environ.get("QUERY_PARAMS", "{}"))  # literal_eval needs a string default
system_prompt = os.environ.get("SYSTEM_PROMPT")
base_prompt = os.environ.get("BASE_PROMPT")
openai_api_key = os.environ.get("OPENAI_API_KEY")


async def upsert_documents(db: Database, collection_name: str, folder: str) -> int:
async def upsert_documents(folder: str) -> int:
log.info("Scanning " + folder + " for markdown files")
md_files = []
# root_dir needs a trailing slash (i.e. /root/dir/)
@@ -107,100 +109,14 @@ async def upsert_documents(db: Database, collection_name: str, folder: str) -> int
documents = []
for md_file in track(md_files, description="Extracting text from markdown"):
with open(md_file, "r") as f:
documents.append({"text": f.read(), "filename": md_file})
documents.append({"text": f.read(), "id": md_file})

log.info("Upserting documents into database")
collection = await db.create_or_get_collection(collection_name)
await collection.upsert_documents(documents)

return len(md_files)


async def generate_chunks(
db: Database,
collection_name: str,
splitter: str = "recursive_character",
splitter_params: dict = {"chunk_size": 1500, "chunk_overlap": 40},
) -> int:
"""
The function `generate_chunks` generates chunks for a given collection in a database and returns the
count of chunks created.

:param db: The `db` parameter is an instance of a database connection or client. It is used to
interact with the database and perform operations such as creating collections, executing queries,
and fetching results
:type db: Database
:param collection_name: The `collection_name` parameter is a string that represents the name of the
collection in the database. It is used to create or get the collection and perform operations on it
:type collection_name: str
:return: The function `generate_chunks` returns an integer, which represents the count of chunks
generated in the specified collection.
"""
log.info("Generating chunks")
collection = await db.create_or_get_collection(collection_name)
splitter_id = await collection.register_text_splitter(splitter, splitter_params)
query_string = """SELECT count(*) from {collection_name}.chunks""".format(
collection_name=collection_name
)
results = await db.query(query_string).fetch_all()
start_chunks = results[0]["count"]
log.info("Starting chunk count: " + str(start_chunks))
await collection.generate_chunks(splitter_id)
results = await db.query(query_string).fetch_all()
log.info("Ending chunk count: " + str(results[0]["count"]))
return results[0]["count"] - start_chunks


async def generate_embeddings(
db: Database,
collection_name: str,
splitter: str = "recursive_character",
splitter_params: dict = {"chunk_size": 1500, "chunk_overlap": 40},
model: str = "intfloat/e5-small",
model_params: dict = {},
) -> int:
"""
The `generate_embeddings` function generates embeddings for text data using a specified model and
splitter.

:param db: The `db` parameter is an instance of a database object. It is used to interact with the
database and perform operations such as creating or getting a collection, registering a text
splitter, registering a model, and generating embeddings
:type db: Database
:param collection_name: The `collection_name` parameter is a string that represents the name of the
collection in the database where the embeddings will be generated
:type collection_name: str
:param splitter: The `splitter` parameter is used to specify the text splitting method to be used
during the embedding generation process. In this case, the value is set to "recursive_character",
which suggests that the text will be split into chunks based on recursive character splitting,
defaults to recursive_character
:type splitter: str (optional)
:param splitter_params: The `splitter_params` parameter is a dictionary that contains the parameters
for the text splitter. In this case, the `splitter_params` dictionary has two keys:
:type splitter_params: dict
:param model: The `model` parameter is the name or identifier of the language model that will be
used to generate the embeddings. In this case, the model is specified as "intfloat/e5-small",
defaults to intfloat/e5-small
:type model: str (optional)
:param model_params: The `model_params` parameter is a dictionary that allows you to specify
additional parameters for the model. These parameters can be used to customize the behavior of the
model during the embedding generation process. The specific parameters that can be included in the
`model_params` dictionary will depend on the specific model you are
:type model_params: dict
:return: an integer value of 0.
"""
log.info("Generating embeddings")
collection = await db.create_or_get_collection(collection_name)
splitter_id = await collection.register_text_splitter(splitter, splitter_params)
model_id = await collection.register_model("embedding", model, model_params)
log.info("Splitter ID: " + str(splitter_id))
start = time()
await collection.generate_embeddings(model_id, splitter_id)
log.info("Embeddings generated in %0.3f seconds" % (time() - start))

return 0


async def generate_response(
messages, openai_api_key, temperature=0.7, max_tokens=256, top_p=0.9
):
Expand All @@ -217,44 +133,20 @@ async def generate_response(
return response["choices"][0]["message"]["content"]


async def ingest_documents(
db: Database,
collection_name: str,
folder: str,
splitter: str,
splitter_params: dict,
model: str,
model_params: dict,
):
total_docs = await upsert_documents(db, collection_name, folder=folder)
total_chunks = await generate_chunks(
db, collection_name, splitter=splitter, splitter_params=splitter_params
)
log.info(
"Total documents: " + str(total_docs) + " Total chunks: " + str(total_chunks)
)

await generate_embeddings(
db,
collection_name,
splitter=splitter,
splitter_params=splitter_params,
model=model,
model_params=model_params,
)
async def ingest_documents(folder: str):
# Add the pipeline to the collection, does nothing if we have already added it
await collection.add_pipeline(pipeline)
# This will upsert, chunk, and embed the contents in the folder
total_docs = await upsert_documents(folder)
log.info("Total documents: " + str(total_docs))


async def get_prompt(user_input: str = ""):
collection = await db.create_or_get_collection(collection_name)
model_id = await collection.register_model("embedding", model, model_params)
splitter_id = await collection.register_text_splitter(splitter, splitter_params)
log.info("Model id: " + str(model_id) + " Splitter id: " + str(splitter_id))
vector_results = await collection.vector_search(
user_input,
model_id=model_id,
splitter_id=splitter_id,
top_k=2,
query_params=query_params,
vector_results = (
await collection.query()
.vector_recall(user_input, pipeline, query_params)
.limit(2)
.fetch_all()
)
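# The recalled chunks are logged, then assembled into the prompt context below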
log.info(vector_results)
context = ""
@@ -322,10 +214,12 @@ async def message_hello(message, say):
intents.message_content = True
client = discord.Client(intents=intents)


@client.event
async def on_ready():
print(f"We have logged in as {client.user}")


@client.event
async def on_message(message):
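# Build the bot's mention token, used to detect when the bot is addressed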
bot_mention = f"<@{client.user.id}>"
@@ -351,15 +245,7 @@ async def run():

if stage == "ingest":
root_dir = args.root_dir
await ingest_documents(
db,
collection_name,
root_dir,
splitter,
splitter_params,
model,
model_params,
)
await ingest_documents(root_dir)

elif stage == "chat":
if chat_interface == "cli":
@@ -369,7 +255,12 @@ def main():


def main():
if stage == "chat" and chat_interface == "discord" and os.environ.get("DISCORD_BOT_TOKEN"):
if (
stage == "chat"
and chat_interface == "discord"
and os.environ.get("DISCORD_BOT_TOKEN")
):
client.run(os.environ["DISCORD_BOT_TOKEN"])
else:
asyncio.run(run())
main()
2 changes: 1 addition & 1 deletion pgml-apps/pgml-chat/pyproject.toml
@@ -11,7 +11,7 @@ packages = [{include = "pgml_chat"}]
python = ">=3.8,<4.0"
openai = "^0.27.8"
rich = "^13.4.2"
pgml = "^0.8.0"
pgml = "^0.9.0"
python-dotenv = "^1.0.0"
click = "^8.1.6"
black = "^23.7.0"