1
1
import asyncio
2
- from pgml import Database
2
+ from pgml import Collection , Model , Splitter , Pipeline
3
3
import logging
4
4
from rich .logging import RichHandler
5
5
from rich .progress import track
@@ -77,26 +77,28 @@ def handler(signum, frame):
77
77
78
78
79
79
# Read the parsed command-line arguments (`args`, produced by `argparse`) into module-level settings.
80
- collection_name = args .collection_name
80
+ collection = Collection ( args .collection_name )
81
81
stage = args .stage
82
82
chat_interface = args .chat_interface
83
83
84
84
# The code below reads configuration from environment variables, falling back to defaults where one is given.
85
85
database_url = os .environ .get ("DATABASE_URL" )
86
- db = Database (database_url )
87
- splitter = os .environ .get ("SPLITTER" , "recursive_character" )
86
+ splitter_name = os .environ .get ("SPLITTER" , "recursive_character" )
88
87
# SPLITTER_PARAMS is expected to hold a Python-literal dict (same convention
# as MODEL_PARAMS / QUERY_PARAMS). Parse it so the value is a dict whether it
# comes from the environment or from the default; previously an
# environment-supplied value stayed a raw string.
splitter_params = ast.literal_eval(
    os.environ.get("SPLITTER_PARAMS", '{"chunk_size": 1500, "chunk_overlap": 40}')
)
91
- model = os .environ .get ("MODEL" , "intfloat/e5-small" )
90
+ splitter = Splitter (splitter_name , splitter_params )
91
+ model_name = os .environ .get ("MODEL" , "intfloat/e5-small" )
92
92
# MODEL_PARAMS holds a Python-literal dict. The fallback must be the *string*
# "{}": ast.literal_eval raises ValueError when handed a dict object, which
# made the script fail whenever MODEL_PARAMS was unset.
model_params = ast.literal_eval(os.environ.get("MODEL_PARAMS", "{}"))
93
+ model = Model (model_name , "pgml" , model_params )
94
+ pipeline = Pipeline (args .collection_name + "_pipeline" , model , splitter )
93
95
# QUERY_PARAMS holds a Python-literal dict. The fallback must be the *string*
# "{}": ast.literal_eval raises ValueError when handed a dict object, which
# made the script fail whenever QUERY_PARAMS was unset.
query_params = ast.literal_eval(os.environ.get("QUERY_PARAMS", "{}"))
94
96
system_prompt = os .environ .get ("SYSTEM_PROMPT" )
95
97
base_prompt = os .environ .get ("BASE_PROMPT" )
96
98
openai_api_key = os .environ .get ("OPENAI_API_KEY" )
97
99
98
100
99
- async def upsert_documents (db : Database , collection_name : str , folder : str ) -> int :
101
+ async def upsert_documents (folder : str ) -> int :
100
102
log .info ("Scanning " + folder + " for markdown files" )
101
103
md_files = []
102
104
# root_dir needs a trailing slash (i.e. /root/dir/)
@@ -107,100 +109,14 @@ async def upsert_documents(db: Database, collection_name: str, folder: str) -> i
107
109
documents = []
108
110
for md_file in track (md_files , description = "Extracting text from markdown" ):
109
111
with open (md_file , "r" ) as f :
110
- documents .append ({"text" : f .read (), "filename " : md_file })
112
+ documents .append ({"text" : f .read (), "id " : md_file })
111
113
112
114
log .info ("Upserting documents into database" )
113
- collection = await db .create_or_get_collection (collection_name )
114
115
await collection .upsert_documents (documents )
115
116
116
117
return len (md_files )
117
118
118
119
119
- async def generate_chunks (
120
- db : Database ,
121
- collection_name : str ,
122
- splitter : str = "recursive_character" ,
123
- splitter_params : dict = {"chunk_size" : 1500 , "chunk_overlap" : 40 },
124
- ) -> int :
125
- """
126
- The function `generate_chunks` generates chunks for a given collection in a database and returns the
127
- count of chunks created.
128
-
129
- :param db: The `db` parameter is an instance of a database connection or client. It is used to
130
- interact with the database and perform operations such as creating collections, executing queries,
131
- and fetching results
132
- :type db: Database
133
- :param collection_name: The `collection_name` parameter is a string that represents the name of the
134
- collection in the database. It is used to create or get the collection and perform operations on it
135
- :type collection_name: str
136
- :return: The function `generate_chunks` returns an integer, which represents the count of chunks
137
- generated in the specified collection.
138
- """
139
- log .info ("Generating chunks" )
140
- collection = await db .create_or_get_collection (collection_name )
141
- splitter_id = await collection .register_text_splitter (splitter , splitter_params )
142
- query_string = """SELECT count(*) from {collection_name}.chunks""" .format (
143
- collection_name = collection_name
144
- )
145
- results = await db .query (query_string ).fetch_all ()
146
- start_chunks = results [0 ]["count" ]
147
- log .info ("Starting chunk count: " + str (start_chunks ))
148
- await collection .generate_chunks (splitter_id )
149
- results = await db .query (query_string ).fetch_all ()
150
- log .info ("Ending chunk count: " + str (results [0 ]["count" ]))
151
- return results [0 ]["count" ] - start_chunks
152
-
153
-
154
- async def generate_embeddings (
155
- db : Database ,
156
- collection_name : str ,
157
- splitter : str = "recursive_character" ,
158
- splitter_params : dict = {"chunk_size" : 1500 , "chunk_overlap" : 40 },
159
- model : str = "intfloat/e5-small" ,
160
- model_params : dict = {},
161
- ) -> int :
162
- """
163
- The `generate_embeddings` function generates embeddings for text data using a specified model and
164
- splitter.
165
-
166
- :param db: The `db` parameter is an instance of a database object. It is used to interact with the
167
- database and perform operations such as creating or getting a collection, registering a text
168
- splitter, registering a model, and generating embeddings
169
- :type db: Database
170
- :param collection_name: The `collection_name` parameter is a string that represents the name of the
171
- collection in the database where the embeddings will be generated
172
- :type collection_name: str
173
- :param splitter: The `splitter` parameter is used to specify the text splitting method to be used
174
- during the embedding generation process. In this case, the value is set to "recursive_character",
175
- which suggests that the text will be split into chunks based on recursive character splitting,
176
- defaults to recursive_character
177
- :type splitter: str (optional)
178
- :param splitter_params: The `splitter_params` parameter is a dictionary that contains the parameters
179
- for the text splitter. In this case, the `splitter_params` dictionary has two keys:
180
- :type splitter_params: dict
181
- :param model: The `model` parameter is the name or identifier of the language model that will be
182
- used to generate the embeddings. In this case, the model is specified as "intfloat/e5-small",
183
- defaults to intfloat/e5-small
184
- :type model: str (optional)
185
- :param model_params: The `model_params` parameter is a dictionary that allows you to specify
186
- additional parameters for the model. These parameters can be used to customize the behavior of the
187
- model during the embedding generation process. The specific parameters that can be included in the
188
- `model_params` dictionary will depend on the specific model you are
189
- :type model_params: dict
190
- :return: an integer value of 0.
191
- """
192
- log .info ("Generating embeddings" )
193
- collection = await db .create_or_get_collection (collection_name )
194
- splitter_id = await collection .register_text_splitter (splitter , splitter_params )
195
- model_id = await collection .register_model ("embedding" , model , model_params )
196
- log .info ("Splitter ID: " + str (splitter_id ))
197
- start = time ()
198
- await collection .generate_embeddings (model_id , splitter_id )
199
- log .info ("Embeddings generated in %0.3f seconds" % (time () - start ))
200
-
201
- return 0
202
-
203
-
204
120
async def generate_response (
205
121
messages , openai_api_key , temperature = 0.7 , max_tokens = 256 , top_p = 0.9
206
122
):
@@ -217,44 +133,20 @@ async def generate_response(
217
133
return response ["choices" ][0 ]["message" ]["content" ]
218
134
219
135
220
- async def ingest_documents (
221
- db : Database ,
222
- collection_name : str ,
223
- folder : str ,
224
- splitter : str ,
225
- splitter_params : dict ,
226
- model : str ,
227
- model_params : dict ,
228
- ):
229
- total_docs = await upsert_documents (db , collection_name , folder = folder )
230
- total_chunks = await generate_chunks (
231
- db , collection_name , splitter = splitter , splitter_params = splitter_params
232
- )
233
- log .info (
234
- "Total documents: " + str (total_docs ) + " Total chunks: " + str (total_chunks )
235
- )
236
-
237
- await generate_embeddings (
238
- db ,
239
- collection_name ,
240
- splitter = splitter ,
241
- splitter_params = splitter_params ,
242
- model = model ,
243
- model_params = model_params ,
244
- )
136
+ async def ingest_documents (folder : str ):
137
+ # Add the pipeline to the collection, does nothing if we have already added it
138
+ await collection .add_pipeline (pipeline )
139
+ # This will upsert, chunk, and embed the contents in the folder
140
+ total_docs = await upsert_documents (folder )
141
+ log .info ("Total documents: " + str (total_docs ))
245
142
246
143
247
144
async def get_prompt (user_input : str = "" ):
248
- collection = await db .create_or_get_collection (collection_name )
249
- model_id = await collection .register_model ("embedding" , model , model_params )
250
- splitter_id = await collection .register_text_splitter (splitter , splitter_params )
251
- log .info ("Model id: " + str (model_id ) + " Splitter id: " + str (splitter_id ))
252
- vector_results = await collection .vector_search (
253
- user_input ,
254
- model_id = model_id ,
255
- splitter_id = splitter_id ,
256
- top_k = 2 ,
257
- query_params = query_params ,
145
+ vector_results = (
146
+ await collection .query ()
147
+ .vector_recall (user_input , pipeline , query_params )
148
+ .limit (2 )
149
+ .fetch_all ()
258
150
)
259
151
log .info (vector_results )
260
152
context = ""
@@ -322,10 +214,12 @@ async def message_hello(message, say):
322
214
intents .message_content = True
323
215
client = discord .Client (intents = intents )
324
216
217
+
325
218
@client .event
326
219
async def on_ready ():
327
220
print (f"We have logged in as { client .user } " )
328
221
222
+
329
223
@client .event
330
224
async def on_message (message ):
331
225
bot_mention = f"<@{ client .user .id } >"
@@ -351,15 +245,7 @@ async def run():
351
245
352
246
if stage == "ingest" :
353
247
root_dir = args .root_dir
354
- await ingest_documents (
355
- db ,
356
- collection_name ,
357
- root_dir ,
358
- splitter ,
359
- splitter_params ,
360
- model ,
361
- model_params ,
362
- )
248
+ await ingest_documents (root_dir )
363
249
364
250
elif stage == "chat" :
365
251
if chat_interface == "cli" :
@@ -369,7 +255,12 @@ async def run():
369
255
370
256
371
257
def main ():
372
- if stage == "chat" and chat_interface == "discord" and os .environ .get ("DISCORD_BOT_TOKEN" ):
258
+ if (
259
+ stage == "chat"
260
+ and chat_interface == "discord"
261
+ and os .environ .get ("DISCORD_BOT_TOKEN" )
262
+ ):
373
263
client .run (os .environ ["DISCORD_BOT_TOKEN" ])
374
264
else :
375
265
asyncio .run (run ())
266
+ main ()
0 commit comments