Skip to content

Commit 21ad471

Browse files
authored
Fix panic when cluster isnt connected (#895)
1 parent 5d5b826 commit 21ad471

File tree

2 files changed

+55
-20
lines changed

2 files changed

+55
-20
lines changed

pgml-dashboard/src/guards.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,37 @@ impl<'r> FromRequest<'r> for &'r Cluster {
8080
}
8181
}
8282

83+
#[derive(Debug)]
84+
pub struct ConnectedCluster<'a> {
85+
pub inner: &'a Cluster,
86+
}
87+
88+
#[rocket::async_trait]
89+
impl<'r> FromRequest<'r> for ConnectedCluster<'r> {
90+
type Error = ();
91+
92+
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
93+
let cluster = match request.guard::<&Cluster>().await {
94+
request::Outcome::Success(cluster) => cluster,
95+
_ => return request::Outcome::Forward(()),
96+
};
97+
98+
if cluster.pool.as_ref().is_some() {
99+
request::Outcome::Success(ConnectedCluster { inner: cluster })
100+
} else {
101+
request::Outcome::Forward(())
102+
}
103+
}
104+
}
105+
83106
impl<'a> Cluster {
84107
pub fn pool(&'a self) -> &'a PgPool {
85108
self.pool.as_ref().unwrap()
86109
}
87110
}
111+
112+
impl<'a> ConnectedCluster<'_> {
113+
pub fn pool(&'a self) -> &'a PgPool {
114+
self.inner.pool()
115+
}
116+
}

pgml-dashboard/src/lib.rs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub mod responses;
1717
pub mod templates;
1818
pub mod utils;
1919

20-
use guards::Cluster;
20+
use guards::{Cluster, ConnectedCluster};
2121
use responses::{BadRequest, Error, ResponseOk};
2222
use templates::{
2323
components::StaticNav, DeploymentsTab, Layout, ModelsTab, NotebooksTab, ProjectsTab,
@@ -47,7 +47,7 @@ pub struct Context {
4747
}
4848

4949
#[get("/projects")]
50-
pub async fn project_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
50+
pub async fn project_index(cluster: ConnectedCluster<'_>) -> Result<ResponseOk, Error> {
5151
Ok(ResponseOk(
5252
templates::Projects {
5353
projects: models::Project::all(cluster.pool()).await?,
@@ -58,7 +58,7 @@ pub async fn project_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
5858
}
5959

6060
#[get("/projects/<id>")]
61-
pub async fn project_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Error> {
61+
pub async fn project_get(cluster: ConnectedCluster<'_>, id: i64) -> Result<ResponseOk, Error> {
6262
let project = models::Project::get_by_id(cluster.pool(), id).await?;
6363
let models = models::Model::get_by_project_id(cluster.pool(), id).await?;
6464

@@ -70,7 +70,7 @@ pub async fn project_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Error
7070
}
7171

7272
#[get("/notebooks")]
73-
pub async fn notebook_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
73+
pub async fn notebook_index(cluster: ConnectedCluster<'_>) -> Result<ResponseOk, Error> {
7474
Ok(ResponseOk(
7575
templates::Notebooks {
7676
notebooks: models::Notebook::all(&cluster.pool()).await?,
@@ -94,7 +94,10 @@ pub async fn notebook_create(
9494
}
9595

9696
#[get("/notebooks/<notebook_id>")]
97-
pub async fn notebook_get(cluster: &Cluster, notebook_id: i64) -> Result<ResponseOk, Error> {
97+
pub async fn notebook_get(
98+
cluster: ConnectedCluster<'_>,
99+
notebook_id: i64,
100+
) -> Result<ResponseOk, Error> {
98101
let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
99102

100103
Ok(ResponseOk(Layout::new("Notebooks").render(
@@ -106,7 +109,10 @@ pub async fn notebook_get(cluster: &Cluster, notebook_id: i64) -> Result<Respons
106109
}
107110

108111
#[post("/notebooks/<notebook_id>/reset")]
109-
pub async fn notebook_reset(cluster: &Cluster, notebook_id: i64) -> Result<Redirect, Error> {
112+
pub async fn notebook_reset(
113+
cluster: ConnectedCluster<'_>,
114+
notebook_id: i64,
115+
) -> Result<Redirect, Error> {
110116
let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
111117
notebook.reset(cluster.pool()).await?;
112118

@@ -140,7 +146,7 @@ pub async fn cell_create(
140146

141147
#[get("/notebooks/<notebook_id>/cell/<cell_id>")]
142148
pub async fn cell_get(
143-
cluster: &Cluster,
149+
cluster: ConnectedCluster<'_>,
144150
notebook_id: i64,
145151
cell_id: i64,
146152
) -> Result<ResponseOk, Error> {
@@ -167,7 +173,7 @@ pub async fn cell_get(
167173

168174
#[post("/notebooks/<notebook_id>/cell/<cell_id>/edit", data = "<data>")]
169175
pub async fn cell_edit(
170-
cluster: &Cluster,
176+
cluster: ConnectedCluster<'_>,
171177
notebook_id: i64,
172178
cell_id: i64,
173179
data: Form<forms::Cell<'_>>,
@@ -203,7 +209,7 @@ pub async fn cell_edit(
203209

204210
#[get("/notebooks/<notebook_id>/cell/<cell_id>/edit")]
205211
pub async fn cell_trigger_edit(
206-
cluster: &Cluster,
212+
cluster: ConnectedCluster<'_>,
207213
notebook_id: i64,
208214
cell_id: i64,
209215
) -> Result<ResponseOk, Error> {
@@ -229,7 +235,7 @@ pub async fn cell_trigger_edit(
229235

230236
#[post("/notebooks/<notebook_id>/cell/<cell_id>/play")]
231237
pub async fn cell_play(
232-
cluster: &Cluster,
238+
cluster: ConnectedCluster<'_>,
233239
notebook_id: i64,
234240
cell_id: i64,
235241
) -> Result<ResponseOk, Error> {
@@ -256,7 +262,7 @@ pub async fn cell_play(
256262

257263
#[post("/notebooks/<notebook_id>/cell/<cell_id>/remove")]
258264
pub async fn cell_remove(
259-
cluster: &Cluster,
265+
cluster: ConnectedCluster<'_>,
260266
notebook_id: i64,
261267
cell_id: i64,
262268
) -> Result<ResponseOk, Error> {
@@ -279,7 +285,7 @@ pub async fn cell_remove(
279285

280286
#[post("/notebooks/<notebook_id>/cell/<cell_id>/delete")]
281287
pub async fn cell_delete(
282-
cluster: &Cluster,
288+
cluster: ConnectedCluster<'_>,
283289
notebook_id: i64,
284290
cell_id: i64,
285291
) -> Result<Redirect, Error> {
@@ -295,7 +301,7 @@ pub async fn cell_delete(
295301
}
296302

297303
#[get("/models")]
298-
pub async fn models_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
304+
pub async fn models_index(cluster: ConnectedCluster<'_>) -> Result<ResponseOk, Error> {
299305
let projects = models::Project::all(cluster.pool()).await?;
300306
let mut models = HashMap::new();
301307
// let mut max_scores = HashMap::new();
@@ -328,7 +334,7 @@ pub async fn models_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
328334
}
329335

330336
#[get("/models/<id>")]
331-
pub async fn models_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Error> {
337+
pub async fn models_get(cluster: ConnectedCluster<'_>, id: i64) -> Result<ResponseOk, Error> {
332338
let model = models::Model::get_by_id(cluster.pool(), id).await?;
333339
let snapshot = models::Snapshot::get_by_id(cluster.pool(), model.snapshot_id).await?;
334340
let project = models::Project::get_by_id(cluster.pool(), model.project_id).await?;
@@ -346,7 +352,7 @@ pub async fn models_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Error>
346352
}
347353

348354
#[get("/snapshots")]
349-
pub async fn snapshots_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
355+
pub async fn snapshots_index(cluster: ConnectedCluster<'_>) -> Result<ResponseOk, Error> {
350356
let snapshots = models::Snapshot::all(cluster.pool()).await?;
351357

352358
Ok(ResponseOk(
@@ -355,7 +361,7 @@ pub async fn snapshots_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
355361
}
356362

357363
#[get("/snapshots/<id>")]
358-
pub async fn snapshots_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Error> {
364+
pub async fn snapshots_get(cluster: ConnectedCluster<'_>, id: i64) -> Result<ResponseOk, Error> {
359365
let snapshot = models::Snapshot::get_by_id(cluster.pool(), id).await?;
360366
let samples = snapshot.samples(cluster.pool(), 500).await?;
361367

@@ -379,7 +385,7 @@ pub async fn snapshots_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Err
379385
}
380386

381387
#[get("/deployments")]
382-
pub async fn deployments_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
388+
pub async fn deployments_index(cluster: ConnectedCluster<'_>) -> Result<ResponseOk, Error> {
383389
let projects = models::Project::all(cluster.pool()).await?;
384390
let mut deployments = HashMap::new();
385391

@@ -401,7 +407,7 @@ pub async fn deployments_index(cluster: &Cluster) -> Result<ResponseOk, Error> {
401407
}
402408

403409
#[get("/deployments/<id>")]
404-
pub async fn deployments_get(cluster: &Cluster, id: i64) -> Result<ResponseOk, Error> {
410+
pub async fn deployments_get(cluster: ConnectedCluster<'_>, id: i64) -> Result<ResponseOk, Error> {
405411
let deployment = models::Deployment::get_by_id(cluster.pool(), id).await?;
406412
let project = models::Project::get_by_id(cluster.pool(), deployment.project_id).await?;
407413
let model = models::Model::get_by_id(cluster.pool(), deployment.model_id).await?;
@@ -424,7 +430,7 @@ pub async fn uploader_index() -> ResponseOk {
424430

425431
#[post("/uploader", data = "<form>")]
426432
pub async fn uploader_upload(
427-
cluster: &Cluster,
433+
cluster: ConnectedCluster<'_>,
428434
form: Form<forms::Upload<'_>>,
429435
) -> Result<Redirect, BadRequest> {
430436
let mut uploaded_file = models::UploadedFile::create(cluster.pool()).await.unwrap();
@@ -446,7 +452,7 @@ pub async fn uploader_upload(
446452
}
447453

448454
#[get("/uploader/done?<table_name>")]
449-
pub async fn uploaded_index(cluster: &Cluster, table_name: &str) -> ResponseOk {
455+
pub async fn uploaded_index(cluster: ConnectedCluster<'_>, table_name: &str) -> ResponseOk {
450456
let sql = templates::Sql::new(
451457
cluster.pool(),
452458
&format!("SELECT * FROM {} LIMIT 10", table_name),

0 commit comments

Comments
 (0)