diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index f28f4db10..121b52b3b 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -854,6 +854,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "data-encoding" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" + [[package]] name = "debugid" version = "0.8.0" @@ -1229,9 +1235,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -1244,9 +1250,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -1254,15 +1260,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -1282,15 +1288,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", @@ -1299,21 +1305,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -2515,6 +2521,7 @@ dependencies = [ "csv-async", "dotenv", "env_logger", + "futures", "glob", "itertools", "lazy_static", @@ -2530,6 +2537,7 @@ dependencies = [ "regex", "reqwest", "rocket", + "rocket_ws", "sailfish", "scraper", "sentry", @@ -3041,8 +3049,8 @@ dependencies = [ [[package]] name = "rocket" -version = "0.5.0-rc.3" -source = "git+https://github.com/SergioBenitez/Rocket#07fe79796f058ab12683ff9e344558bece263274" +version = "0.6.0-dev" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" dependencies = [ "async-stream", "async-trait", @@ -3078,8 +3086,8 @@ dependencies = [ [[package]] name = "rocket_codegen" -version = "0.5.0-rc.3" -source = "git+https://github.com/SergioBenitez/Rocket#07fe79796f058ab12683ff9e344558bece263274" +version = "0.6.0-dev" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" dependencies = [ "devise", "glob", @@ -3094,8 +3102,8 @@ dependencies = [ [[package]] name = "rocket_http" -version = "0.5.0-rc.3" -source = "git+https://github.com/SergioBenitez/Rocket#07fe79796f058ab12683ff9e344558bece263274" +version = "0.6.0-dev" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" dependencies = [ "cookie", "either", @@ -3118,6 +3126,15 @@ dependencies = [ "uncased", ] +[[package]] +name = "rocket_ws" +version = "0.1.0" +source = "git+https://github.com/SergioBenitez/Rocket#7f7d352e453e83f3d23ee12f8965ce75c977fcea" +dependencies = [ + "rocket", + "tokio-tungstenite", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -4337,6 +4354,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.8" @@ -4525,6 +4554,25 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "tungstenite" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typed-arena" version = "2.0.2" diff --git a/pgml-dashboard/Cargo.toml b/pgml-dashboard/Cargo.toml index 47238f6ed..6d1b803dd 100644 --- a/pgml-dashboard/Cargo.toml +++ b/pgml-dashboard/Cargo.toml @@ -50,3 +50,5 @@ tokio = { version = "1", features = ["full"] } url = "2.4" yaml-rust = "0.4" zoomies = { git="https://github.com/HyperparamAI/zoomies.git", branch="master" } +ws = { package = "rocket_ws", git = "https://github.com/SergioBenitez/Rocket" } +futures = "0.3.29" diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index c4b12d0c2..0b8978844 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -1,9 +1,10 @@ use anyhow::Context; -use pgml::{Collection, Pipeline}; +use futures::stream::StreamExt; +use pgml::{types::GeneralJsonAsyncIterator, Collection, OpenSourceAI, Pipeline}; use rand::{distributions::Alphanumeric, Rng}; use reqwest::Client; use rocket::{ - http::Status, + http::{Cookie, CookieJar, Status}, outcome::IntoOutcome, request::{self, FromRequest}, route::Route, @@ -14,11 +15,6 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use std::time::{SystemTime, UNIX_EPOCH}; -use crate::{ - forms, - responses::{Error, ResponseOk}, -}; - pub struct User { chatbot_session_id: String, } @@ -40,32 +36,134 @@ impl<'r> FromRequest<'r> for User { #[derive(Serialize, Deserialize, PartialEq, Eq)] enum ChatRole { + System, User, Bot, } +impl ChatRole { + fn to_model_specific_role(&self, brain: &ChatbotBrain) -> &'static str { + match self { + ChatRole::User => "user", + ChatRole::Bot => match brain { + ChatbotBrain::OpenAIGPT4 + | ChatbotBrain::TekniumOpenHermes25Mistral7B + | ChatbotBrain::Starling7b => "assistant", + ChatbotBrain::GrypheMythoMaxL213b => "model", + }, + ChatRole::System => "system", + } + } +} + #[derive(Clone, Copy, Serialize, Deserialize)] enum ChatbotBrain { OpenAIGPT4, - PostgresMLFalcon180b, - AnthropicClaude, - MetaLlama2, + TekniumOpenHermes25Mistral7B, + GrypheMythoMaxL213b, + Starling7b, +} + +impl ChatbotBrain { + fn is_open_source(&self) -> bool { + !matches!(self, Self::OpenAIGPT4) + } + + fn get_system_message( + &self, + knowledge_base: &KnowledgeBase, + context: &str, + ) -> anyhow::Result { + match self { + Self::OpenAIGPT4 => { + let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?; + let system_prompt = system_prompt + .replace("{topic}", knowledge_base.topic()) + .replace("{persona}", "Engineer") + .replace("{language}", "English"); + Ok(serde_json::json!({ + "role": "system", + "content": system_prompt + })) + } + _ => Ok(serde_json::json!({ + "role": "system", + "content": format!(r#"You are a friendly and helpful chatbot that uses the following documents to answer the user's questions with the best of your ability. There is one rule: Do Not Lie. + +{} + + "#, context) + })), + } + } + + fn into_model_json(self) -> serde_json::Value { + match self { + Self::TekniumOpenHermes25Mistral7B => serde_json::json!({ + "model": "TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ", + "revision": "main", + "device_map": "auto", + "quantization_config": { + "bits": 4, + "max_input_length": 10000 + } + }), + Self::GrypheMythoMaxL213b => serde_json::json!({ + "model": "TheBloke/MythoMax-L2-13B-GPTQ", + "revision": "main", + "device_map": "auto", + "quantization_config": { + "bits": 4, + "max_input_length": 10000 + } + }), + Self::Starling7b => serde_json::json!({ + "model": "TheBloke/Starling-LM-7B-alpha-GPTQ", + "revision": "main", + "device_map": "auto", + "quantization_config": { + "bits": 4, + "max_input_length": 10000 + } + }), + _ => unimplemented!(), + } + } + + fn get_chat_template(&self) -> Option<&'static str> { + match self { + Self::TekniumOpenHermes25Mistral7B => Some("{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"), + Self::GrypheMythoMaxL213b => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### Instruction:\n' + message['content'] + '\n'}}\n{% elif message['role'] == 'system' %}\n{{ message['content'] + '\n'}}\n{% elif message['role'] == 'model' %}\n{{ '### Response:>\n' + message['content'] + eos_token + '\n'}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Response:' }}\n{% endif %}\n{% endfor %}"), + _ => None + } + } } -impl TryFrom for ChatbotBrain { +impl TryFrom<&str> for ChatbotBrain { type Error = anyhow::Error; - fn try_from(value: u8) -> anyhow::Result { + fn try_from(value: &str) -> anyhow::Result { match value { - 0 => Ok(ChatbotBrain::OpenAIGPT4), - 1 => Ok(ChatbotBrain::PostgresMLFalcon180b), - 2 => Ok(ChatbotBrain::AnthropicClaude), - 3 => Ok(ChatbotBrain::MetaLlama2), + "teknium/OpenHermes-2.5-Mistral-7B" => Ok(ChatbotBrain::TekniumOpenHermes25Mistral7B), + "Gryphe/MythoMax-L2-13b" => Ok(ChatbotBrain::GrypheMythoMaxL213b), + "openai" => Ok(ChatbotBrain::OpenAIGPT4), + "berkeley-nest/Starling-LM-7B-alpha" => Ok(ChatbotBrain::Starling7b), _ => Err(anyhow::anyhow!("Invalid brain id")), } } } +impl From for &'static str { + fn from(value: ChatbotBrain) -> Self { + match value { + ChatbotBrain::TekniumOpenHermes25Mistral7B => "teknium/OpenHermes-2.5-Mistral-7B", + ChatbotBrain::GrypheMythoMaxL213b => "Gryphe/MythoMax-L2-13b", + ChatbotBrain::OpenAIGPT4 => "openai", + ChatbotBrain::Starling7b => "berkeley-nest/Starling-LM-7B-alpha", + } + } +} + #[derive(Clone, Copy, Serialize, Deserialize)] enum KnowledgeBase { PostgresML, @@ -95,20 +193,31 @@ impl KnowledgeBase { } } -impl TryFrom for KnowledgeBase { +impl TryFrom<&str> for KnowledgeBase { type Error = anyhow::Error; - fn try_from(value: u8) -> anyhow::Result { + fn try_from(value: &str) -> anyhow::Result { match value { - 0 => Ok(KnowledgeBase::PostgresML), - 1 => Ok(KnowledgeBase::PyTorch), - 2 => Ok(KnowledgeBase::Rust), - 3 => Ok(KnowledgeBase::PostgreSQL), + "postgresml" => Ok(KnowledgeBase::PostgresML), + "pytorch" => Ok(KnowledgeBase::PyTorch), + "rust" => Ok(KnowledgeBase::Rust), + "postgresql" => Ok(KnowledgeBase::PostgreSQL), _ => Err(anyhow::anyhow!("Invalid knowledge base id")), } } } +impl From for &'static str { + fn from(value: KnowledgeBase) -> Self { + match value { + KnowledgeBase::PostgresML => "postgresml", + KnowledgeBase::PyTorch => "pytorch", + KnowledgeBase::Rust => "rust", + KnowledgeBase::PostgreSQL => "postgresql", + } + } +} + #[derive(Serialize, Deserialize)] struct Document { id: String, @@ -122,7 +231,7 @@ struct Document { impl Document { fn new( - text: String, + text: &str, role: ChatRole, user_id: String, model: ChatbotBrain, @@ -139,7 +248,7 @@ impl Document { .as_millis(); Document { id, - text, + text: text.to_string(), role, user_id, model, @@ -149,29 +258,11 @@ impl Document { } } -async fn get_openai_chatgpt_answer( - knowledge_base: KnowledgeBase, - history: &str, - context: &str, - question: &str, -) -> Result { +async fn get_openai_chatgpt_answer(messages: M) -> anyhow::Result { let openai_api_key = std::env::var("OPENAI_API_KEY")?; - let base_prompt = std::env::var("CHATBOT_CHATGPT_BASE_PROMPT")?; - let system_prompt = std::env::var("CHATBOT_CHATGPT_SYSTEM_PROMPT")?; - - let system_prompt = system_prompt - .replace("{topic}", knowledge_base.topic()) - .replace("{persona}", "Engineer") - .replace("{language}", "English"); - - let content = base_prompt - .replace("{history}", history) - .replace("{context}", context) - .replace("{question}", question); - let body = json!({ "model": "gpt-3.5-turbo", - "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": content}], + "messages": messages, "temperature": 0.7 }); @@ -194,60 +285,133 @@ async fn get_openai_chatgpt_answer( Ok(response) } -#[post("/chatbot/get-answer", format = "json", data = "")] -pub async fn chatbot_get_answer( - user: User, - data: Json, -) -> Result { - match wrapped_chatbot_get_answer(user, data).await { - Ok(response) => Ok(ResponseOk( - json!({ - "answer": response, - }) - .to_string(), - )), - Err(error) => { - eprintln!("Error: {:?}", error); - Ok(ResponseOk( - json!({ - "error": error.to_string(), - }) - .to_string(), - )) +struct UpdateHistory { + collection: Collection, + user_document: Document, + model: ChatbotBrain, + knowledge_base: KnowledgeBase, +} + +impl UpdateHistory { + fn new( + collection: Collection, + user_document: Document, + model: ChatbotBrain, + knowledge_base: KnowledgeBase, + ) -> Self { + Self { + collection, + user_document, + model, + knowledge_base, } } + + fn update_history(mut self, chatbot_response: &str) -> anyhow::Result<()> { + let chatbot_document = Document::new( + chatbot_response, + ChatRole::Bot, + self.user_document.user_id.to_owned(), + self.model, + self.knowledge_base, + ); + let new_history_messages: Vec = vec![ + serde_json::to_value(self.user_document).unwrap().into(), + serde_json::to_value(chatbot_document).unwrap().into(), + ]; + // We do not want to block our return waiting for this to happen + tokio::spawn(async move { + self.collection + .upsert_documents(new_history_messages, None) + .await + .expect("Failed to upsert user history"); + }); + Ok(()) + } } -pub async fn wrapped_chatbot_get_answer( - user: User, - data: Json, -) -> Result { - let brain = ChatbotBrain::try_from(data.model)?; - let knowledge_base = KnowledgeBase::try_from(data.knowledge_base)?; - - // Create it up here so the timestamps that order the conversation are accurate - let user_document = Document::new( - data.question.clone(), - ChatRole::User, - user.chatbot_session_id.clone(), - brain, - knowledge_base, - ); +#[derive(Serialize)] +struct StreamResponse { + id: Option, + error: Option, + result: Option, + partial_result: Option, +} - let collection = knowledge_base.collection(); - let collection = Collection::new( - collection, - Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); +impl StreamResponse { + fn from_error(id: Option, error: E) -> Self { + StreamResponse { + id, + error: Some(format!("{error}")), + result: None, + partial_result: None, + } + } + + fn from_result(id: u64, result: &str) -> Self { + StreamResponse { + id: Some(id), + error: None, + result: Some(result.to_string()), + partial_result: None, + } + } + + fn from_partial_result(id: u64, result: &str) -> Self { + StreamResponse { + id: Some(id), + error: None, + result: None, + partial_result: Some(result.to_string()), + } + } +} + +#[get("/chatbot/clear-history")] +pub async fn clear_history(cookies: &CookieJar<'_>) -> Status { + // let cookie = Cookie::build("chatbot_session_id").path("/"); + let cookie = Cookie::new("chatbot_session_id", ""); + cookies.remove(cookie); + Status::Ok +} + +#[derive(Serialize)] +pub struct GetHistoryResponse { + result: Option>, + error: Option, +} + +#[derive(Serialize)] +struct HistoryMessage { + side: String, + content: String, + knowledge_base: String, + brain: String, +} + +#[get("/chatbot/get-history")] +pub async fn chatbot_get_history(user: User) -> Json { + match do_chatbot_get_history(&user, 100).await { + Ok(messages) => Json(GetHistoryResponse { + result: Some(messages), + error: None, + }), + Err(e) => Json(GetHistoryResponse { + result: None, + error: Some(format!("{e}")), + }), + } +} - let mut history_collection = Collection::new( +async fn do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result> { + let history_collection = Collection::new( "ChatHistory", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), ); - let messages = history_collection + let mut messages = history_collection .get_documents(Some( json!({ - "limit": 5, + "limit": limit, "order_by": {"timestamp": "desc"}, "filter": { "metadata": { @@ -263,16 +427,6 @@ pub async fn wrapped_chatbot_get_answer( "user_id": { "$eq": user.chatbot_session_id } - }, - { - "knowledge_base": { - "$eq": knowledge_base - } - }, - { - "model": { - "$eq": brain - } } ] } @@ -282,24 +436,108 @@ pub async fn wrapped_chatbot_get_answer( .into(), )) .await?; - - let mut history = messages + messages.reverse(); + let messages: anyhow::Result> = messages .into_iter() .map(|m| { - // Can probably remove this clone - let chat_role: ChatRole = serde_json::from_value(m["document"]["role"].to_owned())?; - if chat_role == ChatRole::Bot { - Ok(format!("Assistant: {}", m["document"]["text"])) - } else { - Ok(format!("User: {}", m["document"]["text"])) - } + let side: String = m["document"]["role"] + .as_str() + .context("Error parsing chat role")? + .to_string() + .to_lowercase(); + let content: String = m["document"]["text"] + .as_str() + .context("Error parsing text")? + .to_string(); + let model: ChatbotBrain = serde_json::from_value(m["document"]["model"].to_owned()) + .context("Error parsing model")?; + let model: &str = model.into(); + let knowledge_base: KnowledgeBase = + serde_json::from_value(m["document"]["knowledge_base"].to_owned()) + .context("Error parsing knowledge_base")?; + let knowledge_base: &str = knowledge_base.into(); + Ok(HistoryMessage { + side, + content, + brain: model.to_string(), + knowledge_base: knowledge_base.to_string(), + }) }) - .collect::>>()?; - history.reverse(); - let history = history.join("\n"); + .collect(); + messages +} - let pipeline = Pipeline::new("v1", None, None, None); - let context = collection +#[get("/chatbot/get-answer")] +pub async fn chatbot_get_answer(user: User, ws: ws::WebSocket) -> ws::Stream!['static] { + ws::Stream! { ws => + for await message in ws { + let v = process_message(message, &user).await; + match v { + Ok((v, id)) => + match v { + ProcessMessageResponse::StreamResponse((mut it, update_history)) => { + let mut total_text: Vec = Vec::new(); + while let Some(value) = it.next().await { + match value { + Ok(v) => { + let v: &str = v["choices"][0]["delta"]["content"].as_str().unwrap(); + total_text.push(v.to_string()); + yield ws::Message::from(serde_json::to_string(&StreamResponse::from_partial_result(id, v)).unwrap()); + }, + Err(e) => yield ws::Message::from(serde_json::to_string(&StreamResponse::from_error(Some(id), e)).unwrap()) + } + } + update_history.update_history(&total_text.join("")).unwrap(); + }, + ProcessMessageResponse::FullResponse(resp) => { + yield ws::Message::from(serde_json::to_string(&StreamResponse::from_result(id, &resp)).unwrap()); + } + } + Err(e) => { + yield ws::Message::from(serde_json::to_string(&StreamResponse::from_error(None, e)).unwrap()); + } + } + }; + } +} + +enum ProcessMessageResponse { + StreamResponse((GeneralJsonAsyncIterator, UpdateHistory)), + FullResponse(String), +} + +#[derive(Deserialize)] +struct Message { + id: u64, + model: String, + knowledge_base: String, + question: String, +} + +async fn process_message( + message: Result, + user: &User, +) -> anyhow::Result<(ProcessMessageResponse, u64)> { + if let ws::Message::Text(s) = message? { + let data: Message = serde_json::from_str(&s)?; + let brain = ChatbotBrain::try_from(data.model.as_str())?; + let knowledge_base = KnowledgeBase::try_from(data.knowledge_base.as_str())?; + + let user_document = Document::new( + &data.question, + ChatRole::User, + user.chatbot_session_id.clone(), + brain, + knowledge_base, + ); + + let pipeline = Pipeline::new("v1", None, None, None); + let collection = knowledge_base.collection(); + let collection = Collection::new( + collection, + Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), + ); + let context = collection .query() .vector_recall(&data.question, &pipeline, Some(json!({ "instruction": "Represent the Wikipedia question for retrieving supporting documents: " @@ -308,37 +546,152 @@ pub async fn wrapped_chatbot_get_answer( .fetch_all() .await? .into_iter() - .map(|(_, context, metadata)| format!("#### Document {}: {}", metadata["id"], context)) + .map(|(_, context, metadata)| format!("\n\n#### Document {}: \n{}\n\n", metadata["id"], context)) .collect::>() .join("\n"); - let answer = - get_openai_chatgpt_answer(knowledge_base, &history, &context, &data.question).await?; - - let new_history_messages: Vec = vec![ - serde_json::to_value(user_document).unwrap().into(), - serde_json::to_value(Document::new( - answer.clone(), - ChatRole::Bot, - user.chatbot_session_id.clone(), - brain, - knowledge_base, - )) - .unwrap() - .into(), - ]; - - // We do not want to block our return waiting for this to happen - tokio::spawn(async move { - history_collection - .upsert_documents(new_history_messages, None) - .await - .expect("Failed to upsert user history"); - }); + let history_collection = Collection::new( + "ChatHistory", + Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), + ); + let mut messages = history_collection + .get_documents(Some( + json!({ + "limit": 5, + "order_by": {"timestamp": "desc"}, + "filter": { + "metadata": { + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id + } + }, + { + "knowledge_base": { + "$eq": knowledge_base + } + }, + // This is where we would match on the model if we wanted to + ] + } + } - Ok(answer) + }) + .into(), + )) + .await?; + messages.reverse(); + + let (mut history, _) = + messages + .into_iter() + .fold((Vec::new(), None), |(mut new_history, role), value| { + let current_role: ChatRole = + serde_json::from_value(value["document"]["role"].to_owned()) + .expect("Error parsing chat role"); + if let Some(role) = role { + if role == current_role { + match role { + ChatRole::User => new_history.push( + serde_json::json!({ + "role": ChatRole::Bot.to_model_specific_role(&brain), + "content": "*no response due to error*" + }) + .into(), + ), + ChatRole::Bot => new_history.push( + serde_json::json!({ + "role": ChatRole::User.to_model_specific_role(&brain), + "content": "*no response due to error*" + }) + .into(), + ), + _ => panic!("Too many system messages"), + } + } + let new_message: pgml::types::Json = serde_json::json!({ + "role": current_role.to_model_specific_role(&brain), + "content": value["document"]["text"] + }) + .into(); + new_history.push(new_message); + } else if matches!(current_role, ChatRole::User) { + let new_message: pgml::types::Json = serde_json::json!({ + "role": current_role.to_model_specific_role(&brain), + "content": value["document"]["text"] + }) + .into(); + new_history.push(new_message); + } + (new_history, Some(current_role)) + }); + + let system_message = brain.get_system_message(&knowledge_base, &context)?; + history.insert(0, system_message.into()); + + // Need to make sure we aren't about to add two user messages back to back + if let Some(message) = history.last() { + if message["role"].as_str().unwrap() == ChatRole::User.to_model_specific_role(&brain) { + history.push( + serde_json::json!({ + "role": ChatRole::Bot.to_model_specific_role(&brain), + "content": "*no response due to errors*" + }) + .into(), + ); + } + } + history.push( + serde_json::json!({ + "role": ChatRole::User.to_model_specific_role(&brain), + "content": data.question + }) + .into(), + ); + + let update_history = + UpdateHistory::new(history_collection, user_document, brain, knowledge_base); + + if brain.is_open_source() { + let op = OpenSourceAI::new(Some( + std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set"), + )); + let chat_template = brain.get_chat_template(); + let stream = op + .chat_completions_create_stream_async( + brain.into_model_json().into(), + history, + Some(10000), + None, + None, + chat_template.map(|t| t.to_string()), + ) + .await?; + Ok(( + ProcessMessageResponse::StreamResponse((stream, update_history)), + data.id, + )) + } else { + let response = match brain { + ChatbotBrain::OpenAIGPT4 => get_openai_chatgpt_answer(history).await?, + _ => unimplemented!(), + }; + update_history.update_history(&response)?; + Ok((ProcessMessageResponse::FullResponse(response), data.id)) + } + } else { + Err(anyhow::anyhow!("Error invalid message format")) + } } pub fn routes() -> Vec { - routes![chatbot_get_answer] + routes![chatbot_get_answer, chatbot_get_history, clear_history] } diff --git a/pgml-dashboard/src/components/chatbot/chatbot.scss b/pgml-dashboard/src/components/chatbot/chatbot.scss index e4bc2f723..a8b934dd5 100644 --- a/pgml-dashboard/src/components/chatbot/chatbot.scss +++ b/pgml-dashboard/src/components/chatbot/chatbot.scss @@ -19,6 +19,7 @@ div[data-controller="chatbot"] { #chatbot-change-the-brain-title, #knowledge-base-title { + font-size: 1.25rem; padding: 0.5rem; padding-top: 0.85rem; margin-bottom: 1rem; @@ -30,6 +31,7 @@ div[data-controller="chatbot"] { margin-top: calc($spacer * 4); } + div[data-chatbot-target="clear"], .chatbot-brain-option-label, .chatbot-knowledge-base-option-label { cursor: pointer; @@ -37,7 +39,7 @@ div[data-controller="chatbot"] { transition: all 0.1s; } - .chatbot-brain-option-label:hover { + .chatbot-brain-option-label:hover, div[data-chatbot-target="clear"]:hover { background-color: #{$gray-800}; } @@ -59,8 +61,8 @@ div[data-controller="chatbot"] { } .chatbot-brain-option-logo { - height: 30px; width: 30px; + height: 30px; background-position: center; background-repeat: no-repeat; background-size: contain; @@ -70,6 +72,14 @@ div[data-controller="chatbot"] { padding-left: 2rem; } + #brain-knowledge-base-divider-line { + height: 0.15rem; + width: 100%; + background-color: #{$gray-500}; + margin-top: 1.5rem; + margin-bottom: 1.5rem; + } + .chatbot-example-questions { display: none; max-height: 66px; @@ -299,4 +309,10 @@ div[data-controller="chatbot"].chatbot-full { #knowledge-base-wrapper { display: block; } + #brain-knowledge-base-divider-line { + display: none; + } + #clear-history-text { + display: block !important; + } } diff --git a/pgml-dashboard/src/components/chatbot/chatbot_controller.js b/pgml-dashboard/src/components/chatbot/chatbot_controller.js index ef6703b33..d6240c645 100644 --- a/pgml-dashboard/src/components/chatbot/chatbot_controller.js +++ b/pgml-dashboard/src/components/chatbot/chatbot_controller.js @@ -4,6 +4,10 @@ import autosize from "autosize"; import DOMPurify from "dompurify"; import * as marked from "marked"; +const getRandomInt = () => { + return Math.floor(Math.random() * Number.MAX_SAFE_INTEGER); +} + const LOADING_MESSAGE = `
Loading
@@ -11,40 +15,44 @@ const LOADING_MESSAGE = `
`; -const getBackgroundImageURLForSide = (side, knowledgeBase) => { +const getBackgroundImageURLForSide = (side, brain) => { if (side == "user") { return "/dashboard/static/images/chatbot_user.webp"; } else { - if (knowledgeBase == 0) { - return "/dashboard/static/images/owl_gradient.svg"; - } else if (knowledgeBase == 1) { - return "/dashboard/static/images/logos/pytorch.svg"; - } else if (knowledgeBase == 2) { - return "/dashboard/static/images/logos/rust.svg"; - } else if (knowledgeBase == 3) { - return "/dashboard/static/images/logos/postgresql.svg"; + if (brain == "teknium/OpenHermes-2.5-Mistral-7B") { + return "/dashboard/static/images/logos/openhermes.webp" + } else if (brain == "Gryphe/MythoMax-L2-13b") { + return "/dashboard/static/images/logos/mythomax.webp" + } else if (brain == "berkeley-nest/Starling-LM-7B-alpha") { + return "/dashboard/static/images/logos/starling.webp" + } else if (brain == "openai") { + return "/dashboard/static/images/logos/openai.webp" } } }; -const createHistoryMessage = (side, question, id, knowledgeBase) => { - id = id || ""; +const createHistoryMessage = (message) => { + if (message.side == "system") { + return ` +
${message.text}
+ `; + } return ` -
-
- ${question} +
+ ${message.get_html()}
@@ -52,17 +60,29 @@ const createHistoryMessage = (side, question, id, knowledgeBase) => { }; const knowledgeBaseIdToName = (knowledgeBase) => { - if (knowledgeBase == 0) { + if (knowledgeBase == "postgresml") { return "PostgresML"; - } else if (knowledgeBase == 1) { + } else if (knowledgeBase == "pytorch") { return "PyTorch"; - } else if (knowledgeBase == 2) { + } else if (knowledgeBase == "rust") { return "Rust"; - } else if (knowledgeBase == 3) { + } else if (knowledgeBase == "postgresql") { return "PostgreSQL"; } }; +const brainIdToName = (brain) => { + if (brain == "teknium/OpenHermes-2.5-Mistral-7B") { + return "OpenHermes" + } else if (brain == "Gryphe/MythoMax-L2-13b") { + return "MythoMax" + } else if (brain == "berkeley-nest/Starling-LM-7B-alpha") { + return "Starling" + } else if (brain == "openai") { + return "ChatGPT" + } +} + const createKnowledgeBaseNotice = (knowledgeBase) => { return `
Chatting with Knowledge Base ${knowledgeBaseIdToName( @@ -71,21 +91,72 @@ const createKnowledgeBaseNotice = (knowledgeBase) => { `; }; -const getAnswer = async (question, model, knowledgeBase) => { - const response = await fetch("/chatbot/get-answer", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ question, model, knowledgeBase }), - }); - return response.json(); -}; +class Message { + constructor(id, side, brain, text, is_partial=false) { + this.id = id + this.side = side + this.brain = brain + this.text = text + this.is_partial = is_partial + } + + get_html() { + return DOMPurify.sanitize(marked.parse(this.text)); + } +} + +class RawMessage extends Message { + constructor(id, side, text, is_partial=false) { + super(id, side, text, is_partial); + } + + get_html() { + return this.text; + } +} + +class MessageHistory { + constructor() { + this.messageHistory = {}; + } + + add_message(message, knowledgeBase) { + console.log("ADDDING", message, knowledgeBase); + if (!(knowledgeBase in this.messageHistory)) { + this.messageHistory[knowledgeBase] = []; + } + if (message.is_partial) { + let current_message = this.messageHistory[knowledgeBase].find(item => item.id == message.id); + if (!current_message) { + this.messageHistory[knowledgeBase].push(message); + } else { + current_message.text += message.text; + } + } else { + if (this.messageHistory[knowledgeBase].length == 0 || message.side != "system") { + this.messageHistory[knowledgeBase].push(message); + } else if (this.messageHistory[knowledgeBase][this.messageHistory[knowledgeBase].length -1].side == "system") { + this.messageHistory[knowledgeBase][this.messageHistory[knowledgeBase].length -1] = message + } else { + this.messageHistory[knowledgeBase].push(message); + } + } + } + + get_messages(knowledgeBase) { + if (!(knowledgeBase in this.messageHistory)) { + return []; + } else { + return this.messageHistory[knowledgeBase]; + } + } +} export default class extends Controller { initialize() { - this.alertCount = 0; - this.gettingAnswer = false; + this.messageHistory = new MessageHistory(); + this.messageIdToKnowledgeBaseId = {}; + this.expanded = false; this.chatbot = document.getElementById("chatbot"); this.expandContractImage = document.getElementById( @@ -100,55 +171,105 @@ export default class extends Controller { this.exampleQuestions = document.getElementsByClassName( "chatbot-example-questions", ); - this.handleBrainChange(); // This will set our initial brain this.handleKnowledgeBaseChange(); // This will set our initial knowledge base + this.handleBrainChange(); // This will set our initial brain this.handleResize(); + + const url = ((window.location.protocol === "https:") ? "wss://" : "ws://") + window.location.hostname + (((window.location.port != 80) && (window.location.port != 443)) ? ":" + window.location.port : "") + window.location.pathname + "/get-answer"; + this.socket = new WebSocket(url); + this.socket.onmessage = (message) => { + let result = JSON.parse(message.data); + console.log(result); + + if (result.error) { + this.showChatbotAlert("Error", "Error getting chatbot answer"); + console.log(result.error); + this.redrawChat(); // This clears any loading messages + } else { + let message; + if (result.partial_result) { + message = new Message(result.id, "bot", this.brain, result.partial_result, true); + } else { + message = new Message(result.id, "bot", this.brain, result.result); + } + this.messageHistory.add_message(message, this.messageIdToKnowledgeBaseId[message.id]); + this.redrawChat(); + } + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + }; + + this.socket.onclose = () => { + window.setTimeout(() => this.openConnection(), 500); + }; + this.getHistory(); + } + + async clearHistory() { + // This endpoint clears the chatbot_sesion_id cookie + await fetch("/chatbot/clear-history"); + window.location.reload(); + } + + async getHistory() { + const result = await fetch("/chatbot/get-history"); + const history = await result.json(); + if (history.error) { + console.log("Error getting chat history", history.error) + } else { + for (const message of history.result) { + const newMessage = new Message(getRandomInt(), message.side, message.brain, message.content, false); + console.log(newMessage); + this.messageHistory.add_message(newMessage, message.knowledge_base); + } + } + this.redrawChat(); + } + + redrawChat() { + this.chatHistory.innerHTML = ""; + const messages = this.messageHistory.get_messages(this.knowledgeBase); + for (const message of messages) { + console.log("Drawing", message); + this.chatHistory.insertAdjacentHTML( + "beforeend", + createHistoryMessage(message), + ); + } + + // Hide or show example questions + this.hideExampleQuestions(); + if (messages.length == 0 || (messages.length == 1 && messages[0].side == "system")) { + document + .getElementById(`chatbot-example-questions-${this.knowledgeBase}`) + .style.setProperty("display", "flex", "important"); + } + + this.chatHistory.scrollTop = this.chatHistory.scrollHeight; } newUserQuestion(question) { + const message = new Message(getRandomInt(), "user", this.brain, question); + this.messageHistory.add_message(message, this.knowledgeBase); + this.messageIdToKnowledgeBaseId[message.id] = this.knowledgeBase; + this.hideExampleQuestions(); + this.redrawChat(); + + let loadingMessage = new Message("loading", "bot", this.brain, LOADING_MESSAGE); this.chatHistory.insertAdjacentHTML( "beforeend", - createHistoryMessage("user", question), - ); - this.chatHistory.insertAdjacentHTML( - "beforeend", - createHistoryMessage( - "bot", - LOADING_MESSAGE, - "chatbot-loading-message", - this.knowledgeBase, - ), + createHistoryMessage(loadingMessage), ); - this.hideExampleQuestions(); this.chatHistory.scrollTop = this.chatHistory.scrollHeight; - - this.gettingAnswer = true; - getAnswer(question, this.brain, this.knowledgeBase) - .then((answer) => { - if (answer.answer) { - this.chatHistory.insertAdjacentHTML( - "beforeend", - createHistoryMessage( - "bot", - DOMPurify.sanitize(marked.parse(answer.answer)), - "", - this.knowledgeBase, - ), - ); - } else { - this.showChatbotAlert("Error", answer.error); - console.log(answer.error); - } - }) - .catch((error) => { - this.showChatbotAlert("Error", "Error getting chatbot answer"); - console.log(error); - }) - .finally(() => { - document.getElementById("chatbot-loading-message").remove(); - this.chatHistory.scrollTop = this.chatHistory.scrollHeight; - this.gettingAnswer = false; - }); + + let id = getRandomInt(); + this.messageIdToKnowledgeBaseId[id] = this.knowledgeBase; + let socketData = { + id, + question, + model: this.brain, + knowledge_base: this.knowledgeBase + }; + this.socket.send(JSON.stringify(socketData)); } handleResize() { @@ -169,12 +290,10 @@ export default class extends Controller { handleEnter(e) { // This prevents adding a return e.preventDefault(); - + // Don't continue if the question is empty const question = this.questionInput.value.trim(); - if (question.length == 0) { + if (question.length == 0) return; - } - // Handle resetting the input // There is probably a better way to do this, but this was the best/easiest I found this.questionInput.value = ""; @@ -185,105 +304,31 @@ export default class extends Controller { } handleBrainChange() { - // Comment this out when we go back to using brains - this.brain = 0; + let selected = document.querySelector('input[name="chatbot-brain-options"]:checked').value; + if (selected == this.brain) + return; + this.brain = selected; this.questionInput.focus(); - - // Uncomment this out when we go back to using brains - // We could just disable the input, but we would then need to listen for click events so this seems easier - // if (this.gettingAnswer) { - // document.querySelector( - // `input[name="chatbot-brain-options"][value="${this.brain}"]`, - // ).checked = true; - // this.showChatbotAlert( - // "Error", - // "Cannot change brain while chatbot is loading answer", - // ); - // return; - // } - // let selected = parseInt( - // document.querySelector('input[name="chatbot-brain-options"]:checked') - // .value, - // ); - // if (selected == this.brain) { - // return; - // } - // brainToContentMap[this.brain] = this.chatHistory.innerHTML; - // this.chatHistory.innerHTML = brainToContentMap[selected] || ""; - // if (this.chatHistory.innerHTML) { - // this.exampleQuestions.style.setProperty("display", "none", "important"); - // } else { - // this.exampleQuestions.style.setProperty("display", "flex", "important"); - // } - // this.brain = selected; - // this.chatHistory.scrollTop = this.chatHistory.scrollHeight; - // this.questionInput.focus(); + this.addBrainAndKnowledgeBaseChangedSystemMessage(); } handleKnowledgeBaseChange() { - // Uncomment this when we go back to using brains - // let selected = parseInt( - // document.querySelector('input[name="chatbot-knowledge-base-options"]:checked') - // .value, - // ); - // this.knowledgeBase = selected; - - // Comment this out when we go back to using brains - // We could just disable the input, but we would then need to listen for click events so this seems easier - if (this.gettingAnswer) { - document.querySelector( - `input[name="chatbot-knowledge-base-options"][value="${this.knowledgeBase}"]`, - ).checked = true; - this.showChatbotAlert( - "Error", - "Cannot change knowledge base while chatbot is loading answer", - ); - return; - } - let selected = parseInt( - document.querySelector( - 'input[name="chatbot-knowledge-base-options"]:checked', - ).value, - ); - if (selected == this.knowledgeBase) { + let selected = document.querySelector('input[name="chatbot-knowledge-base-options"]:checked').value; + if (selected == this.knowledgeBase) return; - } - - // document.getElementById - this.knowledgeBaseToContentMap[this.knowledgeBase] = - this.chatHistory.innerHTML; - this.chatHistory.innerHTML = this.knowledgeBaseToContentMap[selected] || ""; this.knowledgeBase = selected; - - // This should be extended to insert the new knowledge base notice in the correct place - if (this.chatHistory.childElementCount == 0) { - this.chatHistory.insertAdjacentHTML( - "beforeend", - createKnowledgeBaseNotice(this.knowledgeBase), - ); - this.hideExampleQuestions(); - document - .getElementById( - `chatbot-example-questions-${knowledgeBaseIdToName( - this.knowledgeBase, - )}`, - ) - .style.setProperty("display", "flex", "important"); - } else if (this.chatHistory.childElementCount == 1) { - this.hideExampleQuestions(); - document - .getElementById( - `chatbot-example-questions-${knowledgeBaseIdToName( - this.knowledgeBase, - )}`, - ) - .style.setProperty("display", "flex", "important"); - } else { - this.hideExampleQuestions(); - } - - this.chatHistory.scrollTop = this.chatHistory.scrollHeight; + this.redrawChat(); this.questionInput.focus(); + this.addBrainAndKnowledgeBaseChangedSystemMessage(); + } + + addBrainAndKnowledgeBaseChangedSystemMessage() { + let knowledge_base = knowledgeBaseIdToName(this.knowledgeBase); + let brain = brainIdToName(this.brain); + let content = `Chatting with ${brain} about ${knowledge_base}`; + const newMessage = new Message(getRandomInt(), "system", this.brain, content); + this.messageHistory.add_message(newMessage, this.knowledgeBase); + this.redrawChat(); } handleExampleQuestionClick(e) { diff --git a/pgml-dashboard/src/components/chatbot/mod.rs b/pgml-dashboard/src/components/chatbot/mod.rs index 8bcf23fc4..4b149b96e 100644 --- a/pgml-dashboard/src/components/chatbot/mod.rs +++ b/pgml-dashboard/src/components/chatbot/mod.rs @@ -4,7 +4,7 @@ use sailfish::TemplateOnce; type ExampleQuestions = [(&'static str, [(&'static str, &'static str); 4]); 4]; const EXAMPLE_QUESTIONS: ExampleQuestions = [ ( - "PostgresML", + "postgresml", [ ("How do I", "use pgml.transform()?"), ("Show me", "a query to train a model"), @@ -13,7 +13,7 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ], ), ( - "PyTorch", + "pytorch", [ ("What are", "tensors?"), ("How do I", "train a model?"), @@ -22,7 +22,7 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ], ), ( - "Rust", + "rust", [ ("What is", "a lifetime?"), ("How do I", "use a for loop?"), @@ -31,7 +31,7 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ], ), ( - "PostgreSQL", + "postgresql", [ ("How do I", "join two tables?"), ("What is", "a GIN index?"), @@ -41,79 +41,92 @@ const EXAMPLE_QUESTIONS: ExampleQuestions = [ ), ]; -const KNOWLEDGE_BASES: [&str; 0] = [ - // "Knowledge Base 1", - // "Knowledge Base 2", - // "Knowledge Base 3", - // "Knowledge Base 4", -]; - const KNOWLEDGE_BASES_WITH_LOGO: [KnowledgeBaseWithLogo; 4] = [ - KnowledgeBaseWithLogo::new("PostgresML", "/dashboard/static/images/owl_gradient.svg"), - KnowledgeBaseWithLogo::new("PyTorch", "/dashboard/static/images/logos/pytorch.svg"), - KnowledgeBaseWithLogo::new("Rust", "/dashboard/static/images/logos/rust.svg"), KnowledgeBaseWithLogo::new( + "postgresml", + "PostgresML", + "/dashboard/static/images/owl_gradient.svg", + ), + KnowledgeBaseWithLogo::new( + "pytorch", + "PyTorch", + "/dashboard/static/images/logos/pytorch.svg", + ), + KnowledgeBaseWithLogo::new("rust", "Rust", "/dashboard/static/images/logos/rust.svg"), + KnowledgeBaseWithLogo::new( + "postgresql", "PostgreSQL", "/dashboard/static/images/logos/postgresql.svg", ), ]; struct KnowledgeBaseWithLogo { + id: &'static str, name: &'static str, logo: &'static str, } impl KnowledgeBaseWithLogo { - const fn new(name: &'static str, logo: &'static str) -> Self { - Self { name, logo } + const fn new(id: &'static str, name: &'static str, logo: &'static str) -> Self { + Self { id, name, logo } } } -const CHATBOT_BRAINS: [ChatbotBrain; 0] = [ - // ChatbotBrain::new( - // "PostgresML", - // "Falcon 180b", - // "/dashboard/static/images/owl_gradient.svg", - // ), - // ChatbotBrain::new( - // "OpenAI", - // "ChatGPT", - // "/dashboard/static/images/logos/openai.webp", - // ), - // ChatbotBrain::new( - // "Anthropic", - // "Claude", - // "/dashboard/static/images/logos/anthropic.webp", - // ), - // ChatbotBrain::new( - // "Meta", - // "Llama2 70b", - // "/dashboard/static/images/logos/meta.webp", - // ), +const CHATBOT_BRAINS: [ChatbotBrain; 4] = [ + ChatbotBrain::new( + "teknium/OpenHermes-2.5-Mistral-7B", + "OpenHermes", + "teknium/OpenHermes-2.5-Mistral-7B", + "/dashboard/static/images/logos/openhermes.webp", + ), + ChatbotBrain::new( + "Gryphe/MythoMax-L2-13b", + "MythoMax", + "Gryphe/MythoMax-L2-13b", + "/dashboard/static/images/logos/mythomax.webp", + ), + ChatbotBrain::new( + "openai", + "OpenAI", + "ChatGPT", + "/dashboard/static/images/logos/openai.webp", + ), + ChatbotBrain::new( + "berkeley-nest/Starling-LM-7B-alpha", + "Starling", + "berkeley-nest/Starling-LM-7B-alpha", + "/dashboard/static/images/logos/starling.webp", + ), ]; struct ChatbotBrain { + id: &'static str, provider: &'static str, model: &'static str, logo: &'static str, } -// impl ChatbotBrain { -// const fn new(provider: &'static str, model: &'static str, logo: &'static str) -> Self { -// Self { -// provider, -// model, -// logo, -// } -// } -// } +impl ChatbotBrain { + const fn new( + id: &'static str, + provider: &'static str, + model: &'static str, + logo: &'static str, + ) -> Self { + Self { + id, + provider, + model, + logo, + } + } +} #[derive(TemplateOnce)] #[template(path = "chatbot/template.html")] pub struct Chatbot { - brains: &'static [ChatbotBrain; 0], + brains: &'static [ChatbotBrain; 4], example_questions: &'static ExampleQuestions, - knowledge_bases: &'static [&'static str; 0], knowledge_bases_with_logo: &'static [KnowledgeBaseWithLogo; 4], } @@ -122,7 +135,6 @@ impl Default for Chatbot { Chatbot { brains: &CHATBOT_BRAINS, example_questions: &EXAMPLE_QUESTIONS, - knowledge_bases: &KNOWLEDGE_BASES, knowledge_bases_with_logo: &KNOWLEDGE_BASES_WITH_LOGO, } } diff --git a/pgml-dashboard/src/components/chatbot/template.html b/pgml-dashboard/src/components/chatbot/template.html index 1f47cf865..9da069cce 100644 --- a/pgml-dashboard/src/components/chatbot/template.html +++ b/pgml-dashboard/src/components/chatbot/template.html @@ -1,102 +1,72 @@
-
+
- -
Knowledge Base:
+
Change the Brain:
- <% for (index, knowledge_base) in knowledge_bases_with_logo.iter().enumerate() { %> + <% for (index, brain) in brains.iter().enumerate() { %>
checked <% } %> />
<% } %> - - - -
diff --git a/pgml-dashboard/static/images/logos/mythomax.webp b/pgml-dashboard/static/images/logos/mythomax.webp new file mode 100644 index 000000000..6e6c363b2 Binary files /dev/null and b/pgml-dashboard/static/images/logos/mythomax.webp differ diff --git a/pgml-dashboard/static/images/logos/openhermes.webp b/pgml-dashboard/static/images/logos/openhermes.webp new file mode 100644 index 000000000..3c202681e Binary files /dev/null and b/pgml-dashboard/static/images/logos/openhermes.webp differ diff --git a/pgml-dashboard/static/images/logos/starling.webp b/pgml-dashboard/static/images/logos/starling.webp new file mode 100644 index 000000000..988696b14 Binary files /dev/null and b/pgml-dashboard/static/images/logos/starling.webp differ