@@ -49,21 +49,17 @@ impl OpenSourceAI {
49
49
Self { database_url }
50
50
}
51
51
52
- pub async fn chat_completions_create_async (
52
+ fn create_pipeline_model_name_parameters (
53
53
& self ,
54
54
mut model : Json ,
55
- messages : Json ,
56
- max_tokens : Option < i32 > ,
57
- temperature : Option < f64 > ,
58
- n : Option < i32 > ,
59
- ) -> anyhow:: Result < Json > {
60
- let ( transformer_pipeline, model_name, model_parameters) = if model. is_object ( ) {
55
+ ) -> anyhow:: Result < ( TransformerPipeline , String , Json ) > {
56
+ if model. is_object ( ) {
61
57
let args = model. as_object_mut ( ) . unwrap ( ) ;
62
58
let model_name = args
63
59
. remove ( "model" )
64
60
. context ( "`model` is a required key in the model object" ) ?;
65
61
let model_name = model_name. as_str ( ) . context ( "`model` must be a string" ) ?;
66
- (
62
+ Ok ( (
67
63
TransformerPipeline :: new (
68
64
"conversational" ,
69
65
Some ( model_name. to_string ( ) ) ,
@@ -72,7 +68,7 @@ impl OpenSourceAI {
72
68
) ,
73
69
model_name. to_string ( ) ,
74
70
model,
75
- )
71
+ ) )
76
72
} else {
77
73
let model_name = model
78
74
. as_str ( )
@@ -83,7 +79,7 @@ impl OpenSourceAI {
83
79
mistralai/Mistral-7B-v0.1
84
80
"# ,
85
81
) ?;
86
- (
82
+ Ok ( (
87
83
TransformerPipeline :: new (
88
84
"conversational" ,
89
85
Some ( real_model_name. to_string ( ) ) ,
@@ -92,36 +88,52 @@ mistralai/Mistral-7B-v0.1
92
88
) ,
93
89
model_name. to_string ( ) ,
94
90
parameters,
95
- )
96
- } ;
91
+ ) )
92
+ }
93
+ }
94
+
95
+ pub async fn chat_completions_create_stream_async (
96
+ & self ,
97
+ model : Json ,
98
+ messages : Vec < Json > ,
99
+ max_tokens : Option < i32 > ,
100
+ temperature : Option < f64 > ,
101
+ n : Option < i32 > ,
102
+ ) -> anyhow:: Result < ( ) > {
103
+ Ok ( ( ) )
104
+ }
105
+
106
+ pub async fn chat_completions_create_async (
107
+ & self ,
108
+ model : Json ,
109
+ messages : Vec < Json > ,
110
+ max_tokens : Option < i32 > ,
111
+ temperature : Option < f64 > ,
112
+ n : Option < i32 > ,
113
+ ) -> anyhow:: Result < Json > {
114
+ let ( transformer_pipeline, model_name, model_parameters) =
115
+ self . create_pipeline_model_name_parameters ( model) ?;
97
116
98
117
let max_tokens = max_tokens. unwrap_or ( 1000 ) ;
99
118
let temperature = temperature. unwrap_or ( 0.8 ) ;
100
119
let n = n. unwrap_or ( 1 ) as usize ;
101
- let to_hash = format ! (
102
- "{}{}{}{}" ,
103
- model_parameters. to_string( ) ,
104
- max_tokens,
105
- temperature,
106
- n
107
- ) ;
120
+ let to_hash = format ! ( "{}{}{}{}" , * model_parameters, max_tokens, temperature, n) ;
108
121
let md5_digest = md5:: compute ( to_hash. as_bytes ( ) ) ;
109
122
let fingerprint = uuid:: Uuid :: from_slice ( & md5_digest. 0 ) ?;
110
123
111
- let messages: Vec < Json > = std:: iter:: repeat ( messages) . take ( n) . collect ( ) ;
112
124
let choices = transformer_pipeline
113
125
. transform (
114
126
messages,
115
127
Some (
116
- serde_json:: json!( { "max_length" : max_tokens, "temperature" : temperature } )
128
+ serde_json:: json!( { "max_length" : max_tokens, "temperature" : temperature, "do_sample" : true , "num_return_sequences" : n } )
117
129
. into ( ) ,
118
130
) ,
119
131
)
120
132
. await ?;
121
133
let choices: Vec < Json > = choices
122
134
. as_array ( )
123
135
. context ( "Error parsing return from TransformerPipeline" ) ?
124
- . into_iter ( )
136
+ . iter ( )
125
137
. enumerate ( )
126
138
. map ( |( i, c) | {
127
139
serde_json:: json!( {
@@ -157,7 +169,7 @@ mistralai/Mistral-7B-v0.1
157
169
pub fn chat_completions_create (
158
170
& self ,
159
171
model : Json ,
160
- messages : Json ,
172
+ messages : Vec < Json > ,
161
173
max_tokens : Option < i32 > ,
162
174
temperature : Option < f64 > ,
163
175
n : Option < i32 > ,
@@ -177,14 +189,14 @@ mistralai/Mistral-7B-v0.1
177
189
mod tests {
178
190
use super :: * ;
179
191
180
- #[ sqlx :: test]
181
- async fn can_open_source_ai_create ( ) -> anyhow:: Result < ( ) > {
192
+ #[ test]
193
+ fn can_open_source_ai_create ( ) -> anyhow:: Result < ( ) > {
182
194
let client = OpenSourceAI :: new ( None ) ;
183
- let results = client. chat_completions_create ( Json :: from_serializable ( "mistralai/Mistral-7B-v0.1" ) , serde_json :: json! ( [
184
- { "role" : "system" , "content" : "You are a friendly chatbot who always responds in the style of a pirate" } ,
185
- { "role" : "user" , "content" : "How many helicopters can a human eat in one sitting?" }
186
- ] ) . into ( ) , Some ( 1000 ) , None , None ) ?;
187
- assert ! ( results. as_array( ) . is_some( ) ) ;
195
+ let results = client. chat_completions_create ( Json :: from_serializable ( "mistralai/Mistral-7B-v0.1" ) , vec ! [
196
+ serde_json :: json! ( { "role" : "system" , "content" : "You are a friendly chatbot who always responds in the style of a pirate" } ) . into ( ) ,
197
+ serde_json :: json! ( { "role" : "user" , "content" : "How many helicopters can a human eat in one sitting?" } ) . into ( ) ,
198
+ ] , Some ( 1000 ) , None , Some ( 3 ) ) ?;
199
+ assert ! ( results[ "choices" ] . as_array( ) . is_some( ) ) ;
188
200
Ok ( ( ) )
189
201
}
190
202
}
0 commit comments