@@ -26,7 +26,7 @@ use tensorzero_core::{
26
26
validate_tags,
27
27
} ,
28
28
error:: { Error , ErrorDetails } ,
29
- gateway_util:: { setup_clickhouse, setup_http_client, AppStateData } ,
29
+ gateway_util:: { setup_clickhouse, setup_http_client, GatewayHandle } ,
30
30
} ;
31
31
use thiserror:: Error ;
32
32
use tokio:: { sync:: Mutex , time:: error:: Elapsed } ;
@@ -122,7 +122,7 @@ impl HTTPGateway {
122
122
}
123
123
124
124
struct EmbeddedGateway {
125
- state : AppStateData ,
125
+ handle : GatewayHandle ,
126
126
}
127
127
128
128
/// Used to construct a `Client`
@@ -301,7 +301,7 @@ impl ClientBuilder {
301
301
Ok ( Client {
302
302
mode : ClientMode :: EmbeddedGateway {
303
303
gateway : EmbeddedGateway {
304
- state : AppStateData :: new_with_clickhouse_and_http_client (
304
+ handle : GatewayHandle :: new_with_clickhouse_and_http_client (
305
305
config,
306
306
clickhouse_connection_info,
307
307
http_client,
@@ -318,10 +318,10 @@ impl ClientBuilder {
318
318
}
319
319
320
320
#[ cfg( any( test, feature = "e2e_tests" ) ) ]
321
- pub async fn build_from_state ( state : AppStateData ) -> Result < Client , ClientBuilderError > {
321
+ pub async fn build_from_state ( handle : GatewayHandle ) -> Result < Client , ClientBuilderError > {
322
322
Ok ( Client {
323
323
mode : ClientMode :: EmbeddedGateway {
324
- gateway : EmbeddedGateway { state } ,
324
+ gateway : EmbeddedGateway { handle } ,
325
325
timeout : None ,
326
326
} ,
327
327
verbose_errors : false ,
@@ -374,7 +374,8 @@ impl Client {
374
374
gateway,
375
375
timeout : _,
376
376
} => gateway
377
- . state
377
+ . handle
378
+ . app_state
378
379
. clickhouse_connection_info
379
380
. health ( )
380
381
. await
@@ -406,9 +407,12 @@ impl Client {
406
407
}
407
408
ClientMode :: EmbeddedGateway { gateway, timeout } => {
408
409
Ok ( with_embedded_timeout ( * timeout, async {
409
- tensorzero_core:: endpoints:: feedback:: feedback ( gateway. state . clone ( ) , params)
410
- . await
411
- . map_err ( err_to_http)
410
+ tensorzero_core:: endpoints:: feedback:: feedback (
411
+ gateway. handle . app_state . clone ( ) ,
412
+ params,
413
+ )
414
+ . await
415
+ . map_err ( err_to_http)
412
416
} )
413
417
. await ?)
414
418
}
@@ -482,9 +486,9 @@ impl Client {
482
486
ClientMode :: EmbeddedGateway { gateway, timeout } => {
483
487
Ok ( with_embedded_timeout ( * timeout, async {
484
488
tensorzero_core:: endpoints:: inference:: inference (
485
- gateway. state . config . clone ( ) ,
486
- & gateway. state . http_client ,
487
- gateway. state . clickhouse_connection_info . clone ( ) ,
489
+ gateway. handle . app_state . config . clone ( ) ,
490
+ & gateway. handle . app_state . http_client ,
491
+ gateway. handle . app_state . clickhouse_connection_info . clone ( ) ,
488
492
params. try_into ( ) . map_err ( err_to_http) ?,
489
493
)
490
494
. await
@@ -530,7 +534,7 @@ impl Client {
530
534
ClientMode :: EmbeddedGateway { gateway, timeout } => {
531
535
Ok ( with_embedded_timeout ( * timeout, async {
532
536
tensorzero_core:: endpoints:: object_storage:: get_object (
533
- & gateway. state . config ,
537
+ & gateway. handle . app_state . config ,
534
538
storage_path,
535
539
)
536
540
. await
@@ -568,7 +572,7 @@ impl Client {
568
572
ClientMode :: EmbeddedGateway { gateway, timeout } => {
569
573
Ok ( with_embedded_timeout ( * timeout, async {
570
574
tensorzero_core:: endpoints:: dynamic_evaluation_run:: dynamic_evaluation_run (
571
- gateway. state . clone ( ) ,
575
+ gateway. handle . app_state . clone ( ) ,
572
576
params,
573
577
)
574
578
. await
@@ -598,7 +602,7 @@ impl Client {
598
602
ClientMode :: EmbeddedGateway { gateway, timeout } => {
599
603
Ok ( with_embedded_timeout ( * timeout, async {
600
604
tensorzero_core:: endpoints:: dynamic_evaluation_run:: dynamic_evaluation_run_episode (
601
- gateway. state . clone ( ) ,
605
+ gateway. handle . app_state . clone ( ) ,
602
606
run_id,
603
607
params,
604
608
)
@@ -631,9 +635,9 @@ impl Client {
631
635
tensorzero_core:: endpoints:: datasets:: insert_datapoint (
632
636
dataset_name,
633
637
params,
634
- & gateway. state . config ,
635
- & gateway. state . http_client ,
636
- & gateway. state . clickhouse_connection_info ,
638
+ & gateway. handle . app_state . config ,
639
+ & gateway. handle . app_state . http_client ,
640
+ & gateway. handle . app_state . clickhouse_connection_info ,
637
641
)
638
642
. await
639
643
. map_err ( err_to_http)
@@ -682,7 +686,7 @@ impl Client {
682
686
tensorzero_core:: endpoints:: datasets:: delete_datapoint (
683
687
dataset_name,
684
688
datapoint_id,
685
- & gateway. state . clickhouse_connection_info ,
689
+ & gateway. handle . app_state . clickhouse_connection_info ,
686
690
)
687
691
. await
688
692
. map_err ( err_to_http)
@@ -720,7 +724,7 @@ impl Client {
720
724
Ok ( with_embedded_timeout ( * timeout, async {
721
725
tensorzero_core:: endpoints:: datasets:: list_datapoints (
722
726
dataset_name,
723
- & gateway. state . clickhouse_connection_info ,
727
+ & gateway. handle . app_state . clickhouse_connection_info ,
724
728
function_name,
725
729
limit,
726
730
offset,
@@ -754,7 +758,7 @@ impl Client {
754
758
tensorzero_core:: endpoints:: datasets:: get_datapoint (
755
759
dataset_name,
756
760
datapoint_id,
757
- & gateway. state . clickhouse_connection_info ,
761
+ & gateway. handle . app_state . clickhouse_connection_info ,
758
762
)
759
763
. await
760
764
. map_err ( err_to_http)
@@ -775,7 +779,7 @@ impl Client {
775
779
ClientMode :: EmbeddedGateway { gateway, timeout } => {
776
780
with_embedded_timeout ( * timeout, async {
777
781
tensorzero_core:: endpoints:: datasets:: stale_dataset (
778
- & gateway. state . clickhouse_connection_info ,
782
+ & gateway. handle . app_state . clickhouse_connection_info ,
779
783
& dataset_name,
780
784
)
781
785
. await
@@ -824,9 +828,10 @@ impl Client {
824
828
} ) ;
825
829
} ;
826
830
let inferences = gateway
827
- . state
831
+ . handle
832
+ . app_state
828
833
. clickhouse_connection_info
829
- . list_inferences ( & gateway. state . config , & params)
834
+ . list_inferences ( & gateway. handle . app_state . config , & params)
830
835
. await
831
836
. map_err ( err_to_http) ?;
832
837
Ok ( inferences)
@@ -854,9 +859,13 @@ impl Client {
854
859
. into ( ) ,
855
860
} ) ;
856
861
} ;
857
- render_samples ( gateway. state . config . clone ( ) , stored_samples, variants)
858
- . await
859
- . map_err ( err_to_http)
862
+ render_samples (
863
+ gateway. handle . app_state . config . clone ( ) ,
864
+ stored_samples,
865
+ variants,
866
+ )
867
+ . await
868
+ . map_err ( err_to_http)
860
869
}
861
870
862
871
/// Launch an optimization job.
@@ -868,7 +877,7 @@ impl Client {
868
877
ClientMode :: EmbeddedGateway { gateway, timeout } => {
869
878
// TODO: do we want this?
870
879
Ok ( with_embedded_timeout ( * timeout, async {
871
- launch_optimization ( & gateway. state . http_client , params)
880
+ launch_optimization ( & gateway. handle . app_state . http_client , params)
872
881
. await
873
882
. map_err ( err_to_http)
874
883
} )
@@ -894,9 +903,9 @@ impl Client {
894
903
ClientMode :: EmbeddedGateway { gateway, timeout } => {
895
904
with_embedded_timeout ( * timeout, async {
896
905
launch_optimization_workflow (
897
- & gateway. state . http_client ,
898
- gateway. state . config . clone ( ) ,
899
- & gateway. state . clickhouse_connection_info ,
906
+ & gateway. handle . app_state . http_client ,
907
+ gateway. handle . app_state . config . clone ( ) ,
908
+ & gateway. handle . app_state . clickhouse_connection_info ,
900
909
params,
901
910
)
902
911
. await
@@ -946,7 +955,7 @@ impl Client {
946
955
ClientMode :: EmbeddedGateway { gateway, timeout } => {
947
956
Ok ( with_embedded_timeout ( * timeout, async {
948
957
tensorzero_core:: endpoints:: optimization:: poll_optimization (
949
- & gateway. state . http_client ,
958
+ & gateway. handle . app_state . http_client ,
950
959
job_handle,
951
960
)
952
961
. await
@@ -977,7 +986,9 @@ impl Client {
977
986
978
987
pub fn get_config ( & self ) -> Result < Arc < Config > , TensorZeroError > {
979
988
match & self . mode {
980
- ClientMode :: EmbeddedGateway { gateway, .. } => Ok ( gateway. state . config . clone ( ) ) ,
989
+ ClientMode :: EmbeddedGateway { gateway, .. } => {
990
+ Ok ( gateway. handle . app_state . config . clone ( ) )
991
+ }
981
992
ClientMode :: HTTPGateway ( _) => Err ( TensorZeroError :: Other {
982
993
source : tensorzero_core:: error:: Error :: new ( ErrorDetails :: InvalidClientMode {
983
994
mode : "Http" . to_string ( ) ,
@@ -1152,9 +1163,9 @@ impl Client {
1152
1163
}
1153
1164
1154
1165
#[ cfg( any( feature = "e2e_tests" , feature = "pyo3" ) ) ]
1155
- pub fn get_app_state_data ( & self ) -> Option < & AppStateData > {
1166
+ pub fn get_app_state_data ( & self ) -> Option < & tensorzero_core :: gateway_util :: AppStateData > {
1156
1167
match & self . mode {
1157
- ClientMode :: EmbeddedGateway { gateway, .. } => Some ( & gateway. state ) ,
1168
+ ClientMode :: EmbeddedGateway { gateway, .. } => Some ( & gateway. handle . app_state ) ,
1158
1169
_ => None ,
1159
1170
}
1160
1171
}
0 commit comments