Commit c80817b

Finalized models in SDK for open source ai
1 parent 95e1e9a commit c80817b

1 file changed: +104 -13 lines changed

pgml-sdks/pgml/src/open_source_ai.rs

Lines changed: 104 additions & 13 deletions
@@ -22,43 +22,130 @@ fn try_model_nice_name_to_model_name_and_parameters(
     model_name: &str,
 ) -> Option<(&'static str, Json)> {
     match model_name {
-        "mistralai/Mistral-7B-v0.1" => Some((
-            "TheBloke/zephyr-7B-beta-GPTQ",
+        // Not all models will necessarily have the same parameters / naming relation but they happen to now
+        "mistralai/Mistral-7B-Instruct-v0.1" => Some((
+            "mistralai/Mistral-7B-Instruct-v0.1",
             serde_json::json!({
                 "task": "conversational",
-                "model": "TheBloke/zephyr-7B-beta-GPTQ",
+                "model": "mistralai/Mistral-7B-Instruct-v0.1",
                 "device_map": "auto",
-                "revision": "main",
-                "model_type": "mistral"
+                "torch_dtype": "bfloat16"
             })
             .into(),
         )),
-        "meta-llama/Llama-2-7b-chat-hf" => Some((
+
+        "HuggingFaceH4/zephyr-7b-beta" => Some((
+            "HuggingFaceH4/zephyr-7b-beta",
+            serde_json::json!({
+                "task": "conversational",
+                "model": "HuggingFaceH4/zephyr-7b-beta",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
+        "TheBloke/Llama-2-7B-Chat-GPTQ" => Some((
             "TheBloke/Llama-2-7B-Chat-GPTQ",
             serde_json::json!({
                 "task": "conversational",
-                "model": "TheBloke/zephyr-7B-beta-GPTQ",
+                "model": "TheBloke/Llama-2-7B-Chat-GPTQ",
+                "device_map": "auto",
+                "revision": "main"
+            })
+            .into(),
+        )),
+
+        "teknium/OpenHermes-2.5-Mistral-7B" => Some((
+            "teknium/OpenHermes-2.5-Mistral-7B",
+            serde_json::json!({
+                "task": "conversational",
+                "model": "teknium/OpenHermes-2.5-Mistral-7B",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
+        "Open-Orca/Mistral-7B-OpenOrca" => Some((
+            "Open-Orca/Mistral-7B-OpenOrca",
+            serde_json::json!({
+                "task": "conversational",
+                "model": "Open-Orca/Mistral-7B-OpenOrca",
                 "device_map": "auto",
-                "revision": "main",
-                "model_type": "llama"
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
+        "Undi95/Toppy-M-7B" => Some((
+            "Undi95/Toppy-M-7B",
+            serde_json::json!({
+                "model": "Undi95/Toppy-M-7B",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
             })
             .into(),
         )),
+
+        "Undi95/ReMM-SLERP-L2-13B" => Some((
+            "Undi95/ReMM-SLERP-L2-13B",
+            serde_json::json!({
+                "model": "Undi95/ReMM-SLERP-L2-13B",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
+        "Gryphe/MythoMax-L2-13b" => Some((
+            "Gryphe/MythoMax-L2-13b",
+            serde_json::json!({
+                "model": "Gryphe/MythoMax-L2-13b",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
         "PygmalionAI/mythalion-13b" => Some((
-            "TheBloke/Mythalion-13B-GPTQ",
+            "PygmalionAI/mythalion-13b",
+            serde_json::json!({
+                "model": "PygmalionAI/mythalion-13b",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
+        "deepseek-ai/deepseek-llm-7b-chat" => Some((
+            "deepseek-ai/deepseek-llm-7b-chat",
+            serde_json::json!({
+                "model": "deepseek-ai/deepseek-llm-7b-chat",
+                "device_map": "auto",
+                "torch_dtype": "bfloat16"
+            })
+            .into(),
+        )),
+
+        "Phind/Phind-CodeLlama-34B-v2" => Some((
+            "Phind/Phind-CodeLlama-34B-v2",
             serde_json::json!({
-                "model": "TheBloke/Mythalion-13B-GPTQ",
+                "model": "Phind/Phind-CodeLlama-34B-v2",
                 "device_map": "auto",
-                "revision": "main"
+                "torch_dtype": "bfloat16"
             })
             .into(),
         )),
+
         _ => None,
     }
 }
 
 fn try_get_model_chat_template(model_name: &str) -> Option<&'static str> {
     match model_name {
+        // Any Alpaca instruct tuned model
+        "Undi95/Toppy-M-7B" | "Undi95/ReMM-SLERP-L2-13B" | "Gryphe/MythoMax-L2-13b" | "Phind/Phind-CodeLlama-34B-v2" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### Instruction:\n' + message['content'] + '\n'}}\n{% elif message['role'] == 'system' %}\n{{ message['content'] + '\n'}}\n{% elif message['role'] == 'model' %}\n{{ '### Response:>\n' + message['content'] + eos_token + '\n'}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Response:' }}\n{% endif %}\n{% endfor %}"),
         "PygmalionAI/mythalion-13b" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'model' %}\n{{ '<|model|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|model|>' }}\n{% endif %}\n{% endfor %}"),
         _ => None
     }
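
The hunk above settles the mapping into a consistent shape: every supported nice name now resolves to itself (or, for the one GPTQ build, a TheBloke repository) plus a serde_json parameter object, where full-precision checkpoints get "torch_dtype": "bfloat16" and the GPTQ checkpoint pins "revision": "main" instead; try_get_model_chat_template then pairs the Alpaca-style and Pygmalion-style models with Jinja chat templates. Below is a minimal standalone sketch of that lookup pattern, using serde_json::Value in place of the SDK's Json wrapper and reproducing only two arms for brevity; it is an illustration, not the SDK's code.

// Standalone sketch of the nice-name -> (checkpoint, parameters) lookup
// added in this commit. serde_json::Value stands in for the SDK's Json
// wrapper, and only two representative arms are shown.
use serde_json::{json, Value};

fn try_nice_name_to_model_and_parameters(model_name: &str) -> Option<(&'static str, Value)> {
    match model_name {
        // Full-precision checkpoint: select a dtype.
        "HuggingFaceH4/zephyr-7b-beta" => Some((
            "HuggingFaceH4/zephyr-7b-beta",
            json!({
                "task": "conversational",
                "model": "HuggingFaceH4/zephyr-7b-beta",
                "device_map": "auto",
                "torch_dtype": "bfloat16"
            }),
        )),
        // GPTQ-quantized checkpoint: pin a revision instead of a dtype.
        "TheBloke/Llama-2-7B-Chat-GPTQ" => Some((
            "TheBloke/Llama-2-7B-Chat-GPTQ",
            json!({
                "task": "conversational",
                "model": "TheBloke/Llama-2-7B-Chat-GPTQ",
                "device_map": "auto",
                "revision": "main"
            }),
        )),
        _ => None,
    }
}

fn main() {
    // Known nice names resolve to a checkpoint plus task parameters.
    let (name, params) =
        try_nice_name_to_model_and_parameters("HuggingFaceH4/zephyr-7b-beta").expect("known model");
    println!("{name}: {params}");
    // Unknown names fall through to None, leaving the caller to decide.
    assert!(try_nice_name_to_model_and_parameters("unknown/model").is_none());
}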
@@ -130,6 +217,7 @@ mistralai/Mistral-7B-v0.1
         }
     }
 
+    #[allow(clippy::too_many_arguments)]
     pub async fn chat_completions_create_stream_async(
         &self,
         model: Json,
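
This hunk and the three below each add #[allow(clippy::too_many_arguments)] to a chat-completions entry point. Clippy's too_many_arguments lint fires when a function takes more arguments than its configured threshold (more than seven by default), and the attribute opts a single function out rather than forcing its parameters into a struct. A minimal illustration with a hypothetical function, not taken from this SDK:

// Hypothetical eight-argument function: without the attribute, `cargo clippy`
// reports clippy::too_many_arguments (the default threshold is seven).
#[allow(clippy::too_many_arguments)]
fn configure(a: i32, b: i32, c: i32, d: i32, e: i32, f: i32, g: i32, h: i32) -> i32 {
    a + b + c + d + e + f + g + h
}

fn main() {
    println!("{}", configure(1, 2, 3, 4, 5, 6, 7, 8));
}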
@@ -191,14 +279,15 @@ mistralai/Mistral-7B-v0.1
         Ok(GeneralJsonAsyncIterator(Box::pin(iter)))
     }
 
+    #[allow(clippy::too_many_arguments)]
     pub fn chat_completions_create_stream(
         &self,
         model: Json,
         messages: Vec<Json>,
         max_tokens: Option<i32>,
         temperature: Option<f64>,
-        chat_template: Option<String>,
         n: Option<i32>,
+        chat_template: Option<String>,
     ) -> anyhow::Result<GeneralJsonIterator> {
         let runtime = crate::get_or_set_runtime();
         let iter = runtime.block_on(self.chat_completions_create_stream_async(
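
Besides the lint attribute, this hunk reorders the signature: chat_template now follows n rather than preceding it, a breaking change for positional callers. A hypothetical call site under the new order, where client is assumed to be an OpenSourceAI handle and messages a Vec<Json> built elsewhere (neither appears in this diff):

// Hypothetical call under the reordered signature; note chat_template last.
let iter = client.chat_completions_create_stream(
    serde_json::json!("HuggingFaceH4/zephyr-7b-beta").into(),
    messages,
    Some(1000), // max_tokens
    Some(0.8),  // temperature
    Some(1),    // n
    None,       // chat_template -- moved after n by this commit
)?;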
@@ -214,6 +303,7 @@ mistralai/Mistral-7B-v0.1
         ))))
     }
 
+    #[allow(clippy::too_many_arguments)]
     pub async fn chat_completions_create_async(
         &self,
         model: Json,
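
The context lines in the previous hunk also show how the blocking wrappers work: they fetch a shared runtime via crate::get_or_set_runtime() and drive the async variant with block_on. A self-contained sketch of that sync-over-async shape using tokio directly; the SDK's actual runtime management may differ:

// Minimal sync-over-async pattern: a blocking function that drives an async
// one to completion on a shared tokio runtime, mirroring the SDK's shape.
use tokio::runtime::Runtime;

async fn chat_async(prompt: &str) -> anyhow::Result<String> {
    Ok(format!("echo: {prompt}"))
}

fn chat_blocking(runtime: &Runtime, prompt: &str) -> anyhow::Result<String> {
    runtime.block_on(chat_async(prompt))
}

fn main() -> anyhow::Result<()> {
    let runtime = Runtime::new()?;
    println!("{}", chat_blocking(&runtime, "hello")?);
    Ok(())
}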
@@ -282,6 +372,7 @@ mistralai/Mistral-7B-v0.1
             .into())
     }
 
+    #[allow(clippy::too_many_arguments)]
     pub fn chat_completions_create(
         &self,
         model: Json,
