1
1
use anyhow:: Context ;
2
- use futures:: { StreamExt , Stream } ;
2
+ use futures:: { Stream , StreamExt } ;
3
3
use rust_bridge:: { alias, alias_methods} ;
4
4
use std:: time:: { SystemTime , UNIX_EPOCH } ;
5
5
use uuid:: Uuid ;
6
6
7
7
use crate :: {
8
- types:: { GeneralJsonAsyncIterator , Json , GeneralJsonIterator } ,
9
- TransformerPipeline , get_or_set_runtime,
8
+ get_or_set_runtime,
9
+ types:: { GeneralJsonAsyncIterator , GeneralJsonIterator , Json } ,
10
+ TransformerPipeline ,
10
11
} ;
11
12
12
13
#[ cfg( feature = "python" ) ]
13
- use crate :: types:: { JsonPython , GeneralJsonAsyncIteratorPython , GeneralJsonIteratorPython } ;
14
+ use crate :: types:: { GeneralJsonAsyncIteratorPython , GeneralJsonIteratorPython , JsonPython } ;
14
15
15
16
#[ derive( alias, Debug , Clone ) ]
16
17
pub struct OpenSourceAI {
@@ -43,10 +44,26 @@ fn try_model_nice_name_to_model_name_and_parameters(
43
44
} )
44
45
. into ( ) ,
45
46
) ) ,
47
+ "PygmalionAI/mythalion-13b" => Some ( (
48
+ "TheBloke/Mythalion-13B-GPTQ" ,
49
+ serde_json:: json!( {
50
+ "model" : "TheBloke/Mythalion-13B-GPTQ" ,
51
+ "device_map" : "auto" ,
52
+ "revision" : "main"
53
+ } )
54
+ . into ( ) ,
55
+ ) ) ,
46
56
_ => None ,
47
57
}
48
58
}
49
59
60
/// Returns a built-in Jinja-style chat template for models known to need one,
/// or `None` when no override is registered for `model_name`.
///
/// Callers may still supply their own template; this is only the fallback
/// looked up by model name.
fn try_get_model_chat_template(model_name: &str) -> Option<&'static str> {
    // Template for Pygmalion/Mythalion-style role tags (<|user|>, <|system|>, <|model|>).
    const MYTHALION_CHAT_TEMPLATE: &str = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'model' %}\n{{ '<|model|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|model|>' }}\n{% endif %}\n{% endfor %}";

    if model_name == "PygmalionAI/mythalion-13b" {
        Some(MYTHALION_CHAT_TEMPLATE)
    } else {
        None
    }
}
66
+
50
67
/// Newtype wrapper around a pinned, boxed async stream of `anyhow::Result<Json>`
/// items. An `Iterator` impl elsewhere in this file drains it synchronously,
/// letting blocking callers consume the async streaming chat-completion output.
struct AsyncToSyncJsonIterator(std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>);
51
68
52
69
impl Iterator for AsyncToSyncJsonIterator {
@@ -58,7 +75,13 @@ impl Iterator for AsyncToSyncJsonIterator {
58
75
}
59
76
}
60
77
61
- #[ alias_methods( new, chat_completions_create, chat_completions_create_async, chat_completions_create_stream, chat_completions_create_stream_async) ]
78
+ #[ alias_methods(
79
+ new,
80
+ chat_completions_create,
81
+ chat_completions_create_async,
82
+ chat_completions_create_stream,
83
+ chat_completions_create_stream_async
84
+ ) ]
62
85
impl OpenSourceAI {
63
86
    /// Creates a new `OpenSourceAI` client.
    ///
    /// `database_url` is an optional database connection string; it is stored
    /// on the struct as-is for later use by the client's methods.
    pub fn new(database_url: Option<String>) -> Self {
        Self { database_url }
    }
@@ -114,6 +137,7 @@ mistralai/Mistral-7B-v0.1
114
137
max_tokens : Option < i32 > ,
115
138
temperature : Option < f64 > ,
116
139
n : Option < i32 > ,
140
+ chat_template : Option < String > ,
117
141
) -> anyhow:: Result < GeneralJsonAsyncIterator > {
118
142
let ( transformer_pipeline, model_name, model_parameters) =
119
143
self . create_pipeline_model_name_parameters ( model) ?;
@@ -125,24 +149,26 @@ mistralai/Mistral-7B-v0.1
125
149
let md5_digest = md5:: compute ( to_hash. as_bytes ( ) ) ;
126
150
let fingerprint = uuid:: Uuid :: from_slice ( & md5_digest. 0 ) ?;
127
151
152
+ let mut args = serde_json:: json!( { "max_length" : max_tokens, "temperature" : temperature, "do_sample" : true , "num_return_sequences" : n } ) ;
153
+ if let Some ( t) = chat_template
154
+ . or_else ( || try_get_model_chat_template ( & model_name) . map ( |s| s. to_string ( ) ) )
155
+ {
156
+ args. as_object_mut ( ) . unwrap ( ) . insert (
157
+ "chat_template" . to_string ( ) ,
158
+ serde_json:: to_value ( t) . unwrap ( ) ,
159
+ ) ;
160
+ }
161
+
128
162
let messages = serde_json:: to_value ( messages) ?. into ( ) ;
129
163
let iterator = transformer_pipeline
130
- . transform_stream (
131
- messages,
132
- Some (
133
- serde_json:: json!( { "max_length" : max_tokens, "temperature" : temperature, "do_sample" : true , "num_return_sequences" : n } )
134
- . into ( ) ,
135
- ) ,
136
- Some ( 1 )
137
- )
164
+ . transform_stream ( messages, Some ( args. into ( ) ) , Some ( 1 ) )
138
165
. await ?;
139
166
140
167
let id = Uuid :: new_v4 ( ) . to_string ( ) ;
141
168
let iter = iterator. map ( move |choices| {
142
169
let since_the_epoch = SystemTime :: now ( )
143
170
. duration_since ( UNIX_EPOCH )
144
171
. expect ( "Time went backwards" ) ;
145
- eprintln ! ( "{:?}" , choices) ;
146
172
Ok ( serde_json:: json!( {
147
173
"id" : id. clone( ) ,
148
174
"system_fingerprint" : fingerprint. clone( ) ,
@@ -155,9 +181,8 @@ mistralai/Mistral-7B-v0.1
155
181
"delta" : {
156
182
"role" : "assistant" ,
157
183
"content" : c
158
- }
184
+ }
159
185
} )
160
- // finish_reason goes here
161
186
} ) . collect:: <serde_json:: Value >( )
162
187
} )
163
188
. into ( ) )
@@ -172,11 +197,21 @@ mistralai/Mistral-7B-v0.1
172
197
messages : Vec < Json > ,
173
198
max_tokens : Option < i32 > ,
174
199
temperature : Option < f64 > ,
200
+ chat_template : Option < String > ,
175
201
n : Option < i32 > ,
176
202
) -> anyhow:: Result < GeneralJsonIterator > {
177
203
let runtime = crate :: get_or_set_runtime ( ) ;
178
- let iter = runtime. block_on ( self . chat_completions_create_stream_async ( model, messages, max_tokens, temperature, n) ) ?;
179
- Ok ( GeneralJsonIterator ( Box :: new ( AsyncToSyncJsonIterator ( Box :: pin ( iter) ) ) ) )
204
+ let iter = runtime. block_on ( self . chat_completions_create_stream_async (
205
+ model,
206
+ messages,
207
+ max_tokens,
208
+ temperature,
209
+ n,
210
+ chat_template,
211
+ ) ) ?;
212
+ Ok ( GeneralJsonIterator ( Box :: new ( AsyncToSyncJsonIterator (
213
+ Box :: pin ( iter) ,
214
+ ) ) ) )
180
215
}
181
216
182
217
pub async fn chat_completions_create_async (
@@ -186,6 +221,7 @@ mistralai/Mistral-7B-v0.1
186
221
max_tokens : Option < i32 > ,
187
222
temperature : Option < f64 > ,
188
223
n : Option < i32 > ,
224
+ chat_template : Option < String > ,
189
225
) -> anyhow:: Result < Json > {
190
226
let ( transformer_pipeline, model_name, model_parameters) =
191
227
self . create_pipeline_model_name_parameters ( model) ?;
@@ -197,14 +233,18 @@ mistralai/Mistral-7B-v0.1
197
233
let md5_digest = md5:: compute ( to_hash. as_bytes ( ) ) ;
198
234
let fingerprint = uuid:: Uuid :: from_slice ( & md5_digest. 0 ) ?;
199
235
236
+ let mut args = serde_json:: json!( { "max_length" : max_tokens, "temperature" : temperature, "do_sample" : true , "num_return_sequences" : n } ) ;
237
+ if let Some ( t) = chat_template
238
+ . or_else ( || try_get_model_chat_template ( & model_name) . map ( |s| s. to_string ( ) ) )
239
+ {
240
+ args. as_object_mut ( ) . unwrap ( ) . insert (
241
+ "chat_template" . to_string ( ) ,
242
+ serde_json:: to_value ( t) . unwrap ( ) ,
243
+ ) ;
244
+ }
245
+
200
246
let choices = transformer_pipeline
201
- . transform (
202
- messages,
203
- Some (
204
- serde_json:: json!( { "max_length" : max_tokens, "temperature" : temperature, "do_sample" : true , "num_return_sequences" : n } )
205
- . into ( ) ,
206
- ) ,
207
- )
247
+ . transform ( messages, Some ( args. into ( ) ) )
208
248
. await ?;
209
249
let choices: Vec < Json > = choices
210
250
. as_array ( )
@@ -249,6 +289,7 @@ mistralai/Mistral-7B-v0.1
249
289
max_tokens : Option < i32 > ,
250
290
temperature : Option < f64 > ,
251
291
n : Option < i32 > ,
292
+ chat_template : Option < String > ,
252
293
) -> anyhow:: Result < Json > {
253
294
let runtime = crate :: get_or_set_runtime ( ) ;
254
295
runtime. block_on ( self . chat_completions_create_async (
@@ -257,6 +298,7 @@ mistralai/Mistral-7B-v0.1
257
298
max_tokens,
258
299
temperature,
259
300
n,
301
+ chat_template,
260
302
) )
261
303
}
262
304
}
@@ -272,7 +314,7 @@ mod tests {
272
314
let results = client. chat_completions_create ( Json :: from_serializable ( "mistralai/Mistral-7B-v0.1" ) , vec ! [
273
315
serde_json:: json!( { "role" : "system" , "content" : "You are a friendly chatbot who always responds in the style of a pirate" } ) . into( ) ,
274
316
serde_json:: json!( { "role" : "user" , "content" : "How many helicopters can a human eat in one sitting?" } ) . into( ) ,
275
- ] , Some ( 10 ) , None , Some ( 3 ) ) ?;
317
+ ] , Some ( 10 ) , None , Some ( 3 ) , None ) ?;
276
318
assert ! ( results[ "choices" ] . as_array( ) . is_some( ) ) ;
277
319
Ok ( ( ) )
278
320
}
@@ -283,7 +325,7 @@ mod tests {
283
325
let results = client. chat_completions_create_async ( Json :: from_serializable ( "mistralai/Mistral-7B-v0.1" ) , vec ! [
284
326
serde_json:: json!( { "role" : "system" , "content" : "You are a friendly chatbot who always responds in the style of a pirate" } ) . into( ) ,
285
327
serde_json:: json!( { "role" : "user" , "content" : "How many helicopters can a human eat in one sitting?" } ) . into( ) ,
286
- ] , Some ( 10 ) , None , Some ( 3 ) ) . await ?;
328
+ ] , Some ( 10 ) , None , Some ( 3 ) , None ) . await ?;
287
329
assert ! ( results[ "choices" ] . as_array( ) . is_some( ) ) ;
288
330
Ok ( ( ) )
289
331
}
@@ -294,7 +336,7 @@ mod tests {
294
336
let mut stream = client. chat_completions_create_stream_async ( Json :: from_serializable ( "mistralai/Mistral-7B-v0.1" ) , vec ! [
295
337
serde_json:: json!( { "role" : "system" , "content" : "You are a friendly chatbot who always responds in the style of a pirate" } ) . into( ) ,
296
338
serde_json:: json!( { "role" : "user" , "content" : "How many helicopters can a human eat in one sitting?" } ) . into( ) ,
297
- ] , Some ( 10 ) , None , Some ( 3 ) ) . await ?;
339
+ ] , Some ( 10 ) , None , Some ( 3 ) , None ) . await ?;
298
340
while let Some ( o) = stream. next ( ) . await {
299
341
o?;
300
342
}
@@ -307,11 +349,10 @@ mod tests {
307
349
let iterator = client. chat_completions_create_stream ( Json :: from_serializable ( "mistralai/Mistral-7B-v0.1" ) , vec ! [
308
350
serde_json:: json!( { "role" : "system" , "content" : "You are a friendly chatbot who always responds in the style of a pirate" } ) . into( ) ,
309
351
serde_json:: json!( { "role" : "user" , "content" : "How many helicopters can a human eat in one sitting?" } ) . into( ) ,
310
- ] , Some ( 10 ) , None , Some ( 3 ) ) ?;
352
+ ] , Some ( 10 ) , None , Some ( 3 ) , None ) ?;
311
353
for o in iterator {
312
354
o?;
313
355
}
314
356
Ok ( ( ) )
315
357
}
316
-
317
358
}
0 commit comments