Skip to content

Commit 9de8824

Browse files
authored
Introduce a new GatewayHandle wrapper type (#2965)
* Introduce a new `GatewayHandle` wrapper type This type is not cloneable, and wraps `AppStateData` (which remains cloneable). When we add support for ClickHouse write batching, the Drop impl for `GatewayHandle` will block and wait for the batcher to exit, to ensure that we don't lose writes during a graceful shutdown. See the comments on `GatewayHandle` for more details * Fix clippy
1 parent 42a91db commit 9de8824

File tree

8 files changed

+164
-91
lines changed

8 files changed

+164
-91
lines changed

clients/rust/src/lib.rs

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use tensorzero_core::{
2626
validate_tags,
2727
},
2828
error::{Error, ErrorDetails},
29-
gateway_util::{setup_clickhouse, setup_http_client, AppStateData},
29+
gateway_util::{setup_clickhouse, setup_http_client, GatewayHandle},
3030
};
3131
use thiserror::Error;
3232
use tokio::{sync::Mutex, time::error::Elapsed};
@@ -122,7 +122,7 @@ impl HTTPGateway {
122122
}
123123

124124
struct EmbeddedGateway {
125-
state: AppStateData,
125+
handle: GatewayHandle,
126126
}
127127

128128
/// Used to construct a `Client`
@@ -301,7 +301,7 @@ impl ClientBuilder {
301301
Ok(Client {
302302
mode: ClientMode::EmbeddedGateway {
303303
gateway: EmbeddedGateway {
304-
state: AppStateData::new_with_clickhouse_and_http_client(
304+
handle: GatewayHandle::new_with_clickhouse_and_http_client(
305305
config,
306306
clickhouse_connection_info,
307307
http_client,
@@ -318,10 +318,10 @@ impl ClientBuilder {
318318
}
319319

320320
#[cfg(any(test, feature = "e2e_tests"))]
321-
pub async fn build_from_state(state: AppStateData) -> Result<Client, ClientBuilderError> {
321+
pub async fn build_from_state(handle: GatewayHandle) -> Result<Client, ClientBuilderError> {
322322
Ok(Client {
323323
mode: ClientMode::EmbeddedGateway {
324-
gateway: EmbeddedGateway { state },
324+
gateway: EmbeddedGateway { handle },
325325
timeout: None,
326326
},
327327
verbose_errors: false,
@@ -374,7 +374,8 @@ impl Client {
374374
gateway,
375375
timeout: _,
376376
} => gateway
377-
.state
377+
.handle
378+
.app_state
378379
.clickhouse_connection_info
379380
.health()
380381
.await
@@ -406,9 +407,12 @@ impl Client {
406407
}
407408
ClientMode::EmbeddedGateway { gateway, timeout } => {
408409
Ok(with_embedded_timeout(*timeout, async {
409-
tensorzero_core::endpoints::feedback::feedback(gateway.state.clone(), params)
410-
.await
411-
.map_err(err_to_http)
410+
tensorzero_core::endpoints::feedback::feedback(
411+
gateway.handle.app_state.clone(),
412+
params,
413+
)
414+
.await
415+
.map_err(err_to_http)
412416
})
413417
.await?)
414418
}
@@ -482,9 +486,9 @@ impl Client {
482486
ClientMode::EmbeddedGateway { gateway, timeout } => {
483487
Ok(with_embedded_timeout(*timeout, async {
484488
tensorzero_core::endpoints::inference::inference(
485-
gateway.state.config.clone(),
486-
&gateway.state.http_client,
487-
gateway.state.clickhouse_connection_info.clone(),
489+
gateway.handle.app_state.config.clone(),
490+
&gateway.handle.app_state.http_client,
491+
gateway.handle.app_state.clickhouse_connection_info.clone(),
488492
params.try_into().map_err(err_to_http)?,
489493
)
490494
.await
@@ -530,7 +534,7 @@ impl Client {
530534
ClientMode::EmbeddedGateway { gateway, timeout } => {
531535
Ok(with_embedded_timeout(*timeout, async {
532536
tensorzero_core::endpoints::object_storage::get_object(
533-
&gateway.state.config,
537+
&gateway.handle.app_state.config,
534538
storage_path,
535539
)
536540
.await
@@ -568,7 +572,7 @@ impl Client {
568572
ClientMode::EmbeddedGateway { gateway, timeout } => {
569573
Ok(with_embedded_timeout(*timeout, async {
570574
tensorzero_core::endpoints::dynamic_evaluation_run::dynamic_evaluation_run(
571-
gateway.state.clone(),
575+
gateway.handle.app_state.clone(),
572576
params,
573577
)
574578
.await
@@ -598,7 +602,7 @@ impl Client {
598602
ClientMode::EmbeddedGateway { gateway, timeout } => {
599603
Ok(with_embedded_timeout(*timeout, async {
600604
tensorzero_core::endpoints::dynamic_evaluation_run::dynamic_evaluation_run_episode(
601-
gateway.state.clone(),
605+
gateway.handle.app_state.clone(),
602606
run_id,
603607
params,
604608
)
@@ -631,9 +635,9 @@ impl Client {
631635
tensorzero_core::endpoints::datasets::insert_datapoint(
632636
dataset_name,
633637
params,
634-
&gateway.state.config,
635-
&gateway.state.http_client,
636-
&gateway.state.clickhouse_connection_info,
638+
&gateway.handle.app_state.config,
639+
&gateway.handle.app_state.http_client,
640+
&gateway.handle.app_state.clickhouse_connection_info,
637641
)
638642
.await
639643
.map_err(err_to_http)
@@ -682,7 +686,7 @@ impl Client {
682686
tensorzero_core::endpoints::datasets::delete_datapoint(
683687
dataset_name,
684688
datapoint_id,
685-
&gateway.state.clickhouse_connection_info,
689+
&gateway.handle.app_state.clickhouse_connection_info,
686690
)
687691
.await
688692
.map_err(err_to_http)
@@ -720,7 +724,7 @@ impl Client {
720724
Ok(with_embedded_timeout(*timeout, async {
721725
tensorzero_core::endpoints::datasets::list_datapoints(
722726
dataset_name,
723-
&gateway.state.clickhouse_connection_info,
727+
&gateway.handle.app_state.clickhouse_connection_info,
724728
function_name,
725729
limit,
726730
offset,
@@ -754,7 +758,7 @@ impl Client {
754758
tensorzero_core::endpoints::datasets::get_datapoint(
755759
dataset_name,
756760
datapoint_id,
757-
&gateway.state.clickhouse_connection_info,
761+
&gateway.handle.app_state.clickhouse_connection_info,
758762
)
759763
.await
760764
.map_err(err_to_http)
@@ -775,7 +779,7 @@ impl Client {
775779
ClientMode::EmbeddedGateway { gateway, timeout } => {
776780
with_embedded_timeout(*timeout, async {
777781
tensorzero_core::endpoints::datasets::stale_dataset(
778-
&gateway.state.clickhouse_connection_info,
782+
&gateway.handle.app_state.clickhouse_connection_info,
779783
&dataset_name,
780784
)
781785
.await
@@ -824,9 +828,10 @@ impl Client {
824828
});
825829
};
826830
let inferences = gateway
827-
.state
831+
.handle
832+
.app_state
828833
.clickhouse_connection_info
829-
.list_inferences(&gateway.state.config, &params)
834+
.list_inferences(&gateway.handle.app_state.config, &params)
830835
.await
831836
.map_err(err_to_http)?;
832837
Ok(inferences)
@@ -854,9 +859,13 @@ impl Client {
854859
.into(),
855860
});
856861
};
857-
render_samples(gateway.state.config.clone(), stored_samples, variants)
858-
.await
859-
.map_err(err_to_http)
862+
render_samples(
863+
gateway.handle.app_state.config.clone(),
864+
stored_samples,
865+
variants,
866+
)
867+
.await
868+
.map_err(err_to_http)
860869
}
861870

862871
/// Launch an optimization job.
@@ -868,7 +877,7 @@ impl Client {
868877
ClientMode::EmbeddedGateway { gateway, timeout } => {
869878
// TODO: do we want this?
870879
Ok(with_embedded_timeout(*timeout, async {
871-
launch_optimization(&gateway.state.http_client, params)
880+
launch_optimization(&gateway.handle.app_state.http_client, params)
872881
.await
873882
.map_err(err_to_http)
874883
})
@@ -894,9 +903,9 @@ impl Client {
894903
ClientMode::EmbeddedGateway { gateway, timeout } => {
895904
with_embedded_timeout(*timeout, async {
896905
launch_optimization_workflow(
897-
&gateway.state.http_client,
898-
gateway.state.config.clone(),
899-
&gateway.state.clickhouse_connection_info,
906+
&gateway.handle.app_state.http_client,
907+
gateway.handle.app_state.config.clone(),
908+
&gateway.handle.app_state.clickhouse_connection_info,
900909
params,
901910
)
902911
.await
@@ -946,7 +955,7 @@ impl Client {
946955
ClientMode::EmbeddedGateway { gateway, timeout } => {
947956
Ok(with_embedded_timeout(*timeout, async {
948957
tensorzero_core::endpoints::optimization::poll_optimization(
949-
&gateway.state.http_client,
958+
&gateway.handle.app_state.http_client,
950959
job_handle,
951960
)
952961
.await
@@ -977,7 +986,9 @@ impl Client {
977986

978987
pub fn get_config(&self) -> Result<Arc<Config>, TensorZeroError> {
979988
match &self.mode {
980-
ClientMode::EmbeddedGateway { gateway, .. } => Ok(gateway.state.config.clone()),
989+
ClientMode::EmbeddedGateway { gateway, .. } => {
990+
Ok(gateway.handle.app_state.config.clone())
991+
}
981992
ClientMode::HTTPGateway(_) => Err(TensorZeroError::Other {
982993
source: tensorzero_core::error::Error::new(ErrorDetails::InvalidClientMode {
983994
mode: "Http".to_string(),
@@ -1152,9 +1163,9 @@ impl Client {
11521163
}
11531164

11541165
#[cfg(any(feature = "e2e_tests", feature = "pyo3"))]
1155-
pub fn get_app_state_data(&self) -> Option<&AppStateData> {
1166+
pub fn get_app_state_data(&self) -> Option<&tensorzero_core::gateway_util::AppStateData> {
11561167
match &self.mode {
1157-
ClientMode::EmbeddedGateway { gateway, .. } => Some(&gateway.state),
1168+
ClientMode::EmbeddedGateway { gateway, .. } => Some(&gateway.handle.app_state),
11581169
_ => None,
11591170
}
11601171
}

gateway/src/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ async fn main() {
178178
}
179179

180180
// Initialize AppState
181-
let app_state = gateway_util::AppStateData::new(config.clone())
181+
let gateway_handle = gateway_util::GatewayHandle::new(config.clone())
182182
.await
183183
.expect_pretty("Failed to initialize AppState");
184-
setup_howdy(app_state.clickhouse_connection_info.clone());
184+
setup_howdy(gateway_handle.app_state.clickhouse_connection_info.clone());
185185

186186
// Create a new observability_enabled_pretty string for the log message below
187-
let observability_enabled_pretty = match &app_state.clickhouse_connection_info {
187+
let observability_enabled_pretty = match &gateway_handle.app_state.clickhouse_connection_info {
188188
ClickHouseConnectionInfo::Disabled => "disabled".to_string(),
189189
ClickHouseConnectionInfo::Mock { healthy, .. } => {
190190
format!("mocked (healthy={healthy})")
@@ -301,7 +301,7 @@ async fn main() {
301301
// OTEL exporting is done by the `OtelAxumLayer` above, which is only enabled for certain routes (and includes much more information)
302302
// We log failed requests messages at 'DEBUG', since we already have our own error-logging code,
303303
.layer(TraceLayer::new_for_http().on_failure(DefaultOnFailure::new().level(Level::DEBUG)))
304-
.with_state(app_state);
304+
.with_state(gateway_handle.app_state.clone());
305305

306306
// Bind to the socket address specified in the config, or default to 0.0.0.0:3000
307307
let bind_address = config

tensorzero-core/src/endpoints/feedback.rs

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ mod tests {
832832
use crate::config_parser::{Config, MetricConfig, MetricConfigOptimize};
833833
use crate::function::{FunctionConfigChat, FunctionConfigJson};
834834
use crate::jsonschema_util::StaticJSONSchema;
835-
use crate::testing::get_unit_test_app_state_data;
835+
use crate::testing::get_unit_test_gateway_handle;
836836
use crate::tool::{StaticToolConfig, ToolCallOutput, ToolChoice, ToolConfig};
837837

838838
#[tokio::test]
@@ -1012,7 +1012,7 @@ mod tests {
10121012
let config = Arc::new(Config {
10131013
..Default::default()
10141014
});
1015-
let app_state_data = get_unit_test_app_state_data(config, true);
1015+
let gateway_handle = get_unit_test_gateway_handle(config, true);
10161016
let timestamp = uuid::Timestamp::from_unix_time(1579751960, 0, 0, 0);
10171017
let episode_id = Uuid::new_v7(timestamp);
10181018
let value = json!("test comment");
@@ -1025,8 +1025,11 @@ mod tests {
10251025
internal: false,
10261026
dryrun: Some(false),
10271027
};
1028-
let response =
1029-
feedback_handler(State(app_state_data.clone()), StructuredJson(params)).await;
1028+
let response = feedback_handler(
1029+
State(gateway_handle.app_state.clone()),
1030+
StructuredJson(params),
1031+
)
1032+
.await;
10301033
let details = response.unwrap_err().get_owned_details();
10311034
assert_eq!(
10321035
details,
@@ -1041,7 +1044,7 @@ mod tests {
10411044
let config = Arc::new(Config {
10421045
..Default::default()
10431046
});
1044-
let app_state_data = get_unit_test_app_state_data(config, true);
1047+
let gateway_handle = get_unit_test_gateway_handle(config, true);
10451048
let timestamp = uuid::Timestamp::from_unix_time(1579751960, 0, 0, 0);
10461049
let episode_id = Uuid::new_v7(timestamp);
10471050
let value = json!("test demonstration");
@@ -1056,9 +1059,12 @@ mod tests {
10561059
dryrun: Some(false),
10571060
internal: false,
10581061
};
1059-
let response = feedback_handler(State(app_state_data.clone()), StructuredJson(params))
1060-
.await
1061-
.unwrap_err();
1062+
let response = feedback_handler(
1063+
State(gateway_handle.app_state.clone()),
1064+
StructuredJson(params),
1065+
)
1066+
.await
1067+
.unwrap_err();
10621068
let details = response.get_owned_details();
10631069
assert_eq!(
10641070
details,
@@ -1079,8 +1085,11 @@ mod tests {
10791085
dryrun: Some(false),
10801086
internal: false,
10811087
};
1082-
let response =
1083-
feedback_handler(State(app_state_data.clone()), StructuredJson(params)).await;
1088+
let response = feedback_handler(
1089+
State(gateway_handle.app_state.clone()),
1090+
StructuredJson(params),
1091+
)
1092+
.await;
10841093
let details = response.unwrap_err().get_owned_details();
10851094
assert_eq!(
10861095
details,
@@ -1105,7 +1114,7 @@ mod tests {
11051114
metrics,
11061115
..Default::default()
11071116
});
1108-
let app_state_data = get_unit_test_app_state_data(config.clone(), true);
1117+
let gateway_handle = get_unit_test_gateway_handle(config.clone(), true);
11091118
let value = json!(4.5);
11101119
let timestamp = uuid::Timestamp::from_unix_time(1579751960, 0, 0, 0);
11111120
let inference_id = Uuid::new_v7(timestamp);
@@ -1121,9 +1130,12 @@ mod tests {
11211130
dryrun: Some(false),
11221131
internal: false,
11231132
};
1124-
let response = feedback_handler(State(app_state_data.clone()), StructuredJson(params))
1125-
.await
1126-
.unwrap_err();
1133+
let response = feedback_handler(
1134+
State(gateway_handle.app_state.clone()),
1135+
StructuredJson(params),
1136+
)
1137+
.await
1138+
.unwrap_err();
11271139
let details = response.get_owned_details();
11281140
assert_eq!(
11291141
details,
@@ -1142,8 +1154,11 @@ mod tests {
11421154
dryrun: Some(false),
11431155
internal: false,
11441156
};
1145-
let response =
1146-
feedback_handler(State(app_state_data.clone()), StructuredJson(params)).await;
1157+
let response = feedback_handler(
1158+
State(gateway_handle.app_state.clone()),
1159+
StructuredJson(params),
1160+
)
1161+
.await;
11471162
let details = response.unwrap_err().get_owned_details();
11481163
assert_eq!(
11491164
details,
@@ -1168,7 +1183,7 @@ mod tests {
11681183
metrics,
11691184
..Default::default()
11701185
});
1171-
let app_state_data = get_unit_test_app_state_data(config.clone(), true);
1186+
let gateway_handle = get_unit_test_gateway_handle(config.clone(), true);
11721187
let value = json!(true);
11731188
let timestamp = uuid::Timestamp::from_unix_time(1579751960, 0, 0, 0);
11741189
let inference_id = Uuid::new_v7(timestamp);
@@ -1181,8 +1196,11 @@ mod tests {
11811196
dryrun: None,
11821197
internal: false,
11831198
};
1184-
let response =
1185-
feedback_handler(State(app_state_data.clone()), StructuredJson(params)).await;
1199+
let response = feedback_handler(
1200+
State(gateway_handle.app_state.clone()),
1201+
StructuredJson(params),
1202+
)
1203+
.await;
11861204
let details = response.unwrap_err().get_owned_details();
11871205
assert_eq!(
11881206
details,

0 commit comments

Comments
 (0)