@@ -348,7 +348,7 @@ def get_param_type(self, name: str) -> Optional[Type[Any]]:
348
348
return result .param_type
349
349
350
350
351
- class SendedRequestEntry :
351
+ class SentRequestEntry :
352
352
def __init__ (self , future : Task [Any ], result_type : Optional [Type [Any ]]) -> None :
353
353
self .future = future
354
354
self .result_type = result_type
@@ -420,6 +420,8 @@ def eof_received(self) -> Optional[bool]:
420
420
re .DOTALL ,
421
421
)
422
422
423
+ DEFAULT_MESSAGE_LENGTH : Final = 1
424
+
423
425
def data_received (self , data : bytes ) -> None :
424
426
while len (data ):
425
427
# Append the incoming chunk to the message buffer
@@ -429,7 +431,7 @@ def data_received(self, data: bytes) -> None:
429
431
found = self .MESSAGE_PATTERN .match (self ._message_buf )
430
432
431
433
body = found .group ("body" ) if found else b""
432
- length = int (found .group ("length" )) if found else 1
434
+ length = int (found .group ("length" )) if found else self . DEFAULT_MESSAGE_LENGTH
433
435
434
436
charset = (
435
437
found .group ("charset" ).decode ("ascii" ) if found and found .group ("charset" ) is not None else self .CHARSET
@@ -453,9 +455,9 @@ class JsonRPCProtocol(JsonRPCProtocolBase):
453
455
454
456
def __init__ (self ) -> None :
455
457
super ().__init__ ()
456
- self ._sended_request_lock = threading .RLock ()
457
- self ._sended_request : OrderedDict [Union [str , int ], SendedRequestEntry ] = OrderedDict ()
458
- self ._sended_request_count = 0
458
+ self ._sent_request_lock = threading .RLock ()
459
+ self ._sent_request : OrderedDict [Union [str , int ], SentRequestEntry ] = OrderedDict ()
460
+ self ._sent_request_count = 0
459
461
self ._received_request : OrderedDict [Union [str , int , None ], ReceivedRequestEntry ] = OrderedDict ()
460
462
self ._received_request_lock = threading .RLock ()
461
463
self ._signature_cache : Dict [Callable [..., Any ], inspect .Signature ] = {}
@@ -567,11 +569,11 @@ def send_request(
567
569
) -> Task [_TResult ]:
568
570
result : Task [_TResult ] = Task ()
569
571
570
- with self ._sended_request_lock :
571
- self ._sended_request_count += 1
572
- id = self ._sended_request_count
572
+ with self ._sent_request_lock :
573
+ self ._sent_request_count += 1
574
+ id = self ._sent_request_count
573
575
574
- self ._sended_request [id ] = SendedRequestEntry (result , return_type )
576
+ self ._sent_request [id ] = SentRequestEntry (result , return_type )
575
577
576
578
request = JsonRPCRequest (id = id , method = method , params = params )
577
579
self .send_message (request )
@@ -598,8 +600,8 @@ async def handle_response(self, message: JsonRPCResponse) -> None:
598
600
self .send_error (JsonRPCErrors .INTERNAL_ERROR , error )
599
601
return
600
602
601
- with self ._sended_request_lock :
602
- entry = self ._sended_request .pop (message .id , None )
603
+ with self ._sent_request_lock :
604
+ entry = self ._sent_request .pop (message .id , None )
603
605
604
606
if entry is None :
605
607
error = f"Invalid response. Could not find id '{ message .id } ' in request list."
@@ -628,8 +630,8 @@ async def handle_error(self, message: JsonRPCError) -> None:
628
630
self .__logger .warning (error )
629
631
raise JsonRPCErrorException (message .error .code , message .error .message , message .error .data )
630
632
631
- with self ._sended_request_lock :
632
- entry = self ._sended_request .pop (message .id , None )
633
+ with self ._sent_request_lock :
634
+ entry = self ._sent_request .pop (message .id , None )
633
635
634
636
if entry is None :
635
637
error = f"Invalid response. Could not find id '{ message .id } ' in request list."
@@ -660,71 +662,158 @@ def _convert_params(
660
662
params_type : Optional [Type [Any ]],
661
663
params : Any ,
662
664
) -> Tuple [List [Any ], Dict [str , Any ]]:
665
+ """Convert JSON-RPC parameters to function arguments.
666
+
667
+ Args:
668
+ callable: The target function to call
669
+ params_type: Expected parameter type for conversion
670
+ params: Raw parameters from JSON-RPC message
671
+
672
+ Returns:
673
+ Tuple of (positional_args, keyword_args) for function call
674
+ """
663
675
if params is None :
664
676
return [], {}
677
+
665
678
if params_type is None :
666
- if isinstance (params , Mapping ):
667
- return [], dict (** params )
679
+ return self ._handle_untyped_params (params )
680
+
681
+ return self ._handle_typed_params (callable , params_type , params )
668
682
669
- return [params ], {}
683
+ def _handle_untyped_params (self , params : Any ) -> Tuple [List [Any ], Dict [str , Any ]]:
684
+ """Handle parameters when no specific type is expected."""
685
+ if isinstance (params , Mapping ):
686
+ return [], dict (** params )
687
+ return [params ], {}
670
688
671
- # try to convert the dict to correct type
689
+ def _handle_typed_params (
690
+ self ,
691
+ callable : Callable [..., Any ],
692
+ params_type : Type [Any ],
693
+ params : Any ,
694
+ ) -> Tuple [List [Any ], Dict [str , Any ]]:
695
+ """Handle parameters with type conversion and signature matching."""
696
+ # Convert the parameters to the expected type
672
697
converted_params = from_dict (params , params_type )
673
698
674
- # get the signature of the callable
699
+ # Get cached signature or create new one
700
+ signature = self ._get_cached_signature (callable )
701
+
702
+ # Extract field names from converted parameters
703
+ field_names = self ._extract_field_names (converted_params )
704
+
705
+ # Process signature parameters
706
+ return self ._process_signature_parameters (signature , converted_params , params , field_names )
707
+
708
+ def _get_cached_signature (self , callable : Callable [..., Any ]) -> inspect .Signature :
709
+ """Get or cache the signature of a callable."""
675
710
if callable in self ._signature_cache :
676
- signature = self ._signature_cache [callable ]
677
- else :
678
- signature = inspect .signature (callable )
679
- self ._signature_cache [callable ] = signature
711
+ return self ._signature_cache [callable ]
712
+
713
+ signature = inspect .signature (callable )
714
+ self ._signature_cache [callable ] = signature
715
+ return signature
680
716
717
+ def _extract_field_names (self , converted_params : Any ) -> List [str ]:
718
+ """Extract field names from converted parameters."""
719
+ if is_dataclass (converted_params ):
720
+ return [f .name for f in fields (converted_params )]
721
+ return list (converted_params .__dict__ .keys ())
722
+
723
+ def _process_signature_parameters (
724
+ self ,
725
+ signature : inspect .Signature ,
726
+ converted_params : Any ,
727
+ params : Any ,
728
+ field_names : List [str ],
729
+ ) -> Tuple [List [Any ], Dict [str , Any ]]:
730
+ """Process function signature parameters and map them to arguments."""
681
731
has_var_kw = any (p .kind == inspect .Parameter .VAR_KEYWORD for p in signature .parameters .values ())
682
732
683
- kw_args = {}
684
- args = []
733
+ kw_args : Dict [ str , Any ] = {}
734
+ args : List [ Any ] = []
685
735
params_added = False
686
736
687
- field_names = (
688
- [f .name for f in fields (converted_params )]
689
- if is_dataclass (converted_params )
690
- else list (converted_params .__dict__ .keys ())
691
- )
692
-
693
737
rest = set (field_names )
694
738
if isinstance (params , dict ):
695
739
rest = set .union (rest , params .keys ())
696
740
697
- for v in signature .parameters .values ():
698
- if v .name in field_names :
699
- if v .kind == inspect .Parameter .POSITIONAL_ONLY :
700
- args .append (getattr (converted_params , v .name ))
701
- else :
702
- kw_args [v .name ] = getattr (converted_params , v .name )
703
-
704
- rest .remove (v .name )
705
- elif v .name == "params" :
706
- if v .kind == inspect .Parameter .POSITIONAL_ONLY :
707
- args .append (converted_params )
708
- params_added = True
709
- else :
710
- kw_args [v .name ] = converted_params
711
- params_added = True
712
- elif isinstance (params , dict ) and v .name in params :
713
- if v .kind == inspect .Parameter .POSITIONAL_ONLY :
714
- args .append (params [v .name ])
715
- else :
716
- kw_args [v .name ] = params [v .name ]
741
+ # Map signature parameters to arguments
742
+ for param in signature .parameters .values ():
743
+ if param .name in field_names :
744
+ self ._add_field_parameter (param , converted_params , args , kw_args )
745
+ rest .remove (param .name )
746
+ elif param .name == "params" :
747
+ self ._add_params_parameter (param , converted_params , args , kw_args )
748
+ params_added = True
749
+ elif isinstance (params , dict ) and param .name in params :
750
+ self ._add_dict_parameter (param , params , args , kw_args )
751
+
752
+ # Handle remaining parameters if function accepts **kwargs
717
753
if has_var_kw :
718
- for r in rest :
719
- if hasattr (converted_params , r ):
720
- kw_args [r ] = getattr (converted_params , r )
721
- elif isinstance (params , dict ) and r in params :
722
- kw_args [r ] = params [r ]
723
-
724
- if not params_added :
725
- kw_args ["params" ] = converted_params
754
+ self ._handle_var_keywords (rest , converted_params , params , kw_args , params_added )
755
+
726
756
return args , kw_args
727
757
758
+ def _add_field_parameter (
759
+ self ,
760
+ param : inspect .Parameter ,
761
+ converted_params : Any ,
762
+ args : List [Any ],
763
+ kw_args : Dict [str , Any ],
764
+ ) -> None :
765
+ """Add a parameter from converted_params fields."""
766
+ value = getattr (converted_params , param .name )
767
+ if param .kind == inspect .Parameter .POSITIONAL_ONLY :
768
+ args .append (value )
769
+ else :
770
+ kw_args [param .name ] = value
771
+
772
+ def _add_params_parameter (
773
+ self ,
774
+ param : inspect .Parameter ,
775
+ converted_params : Any ,
776
+ args : List [Any ],
777
+ kw_args : Dict [str , Any ],
778
+ ) -> None :
779
+ """Add the entire converted_params as 'params' parameter."""
780
+ if param .kind == inspect .Parameter .POSITIONAL_ONLY :
781
+ args .append (converted_params )
782
+ else :
783
+ kw_args [param .name ] = converted_params
784
+
785
+ def _add_dict_parameter (
786
+ self ,
787
+ param : inspect .Parameter ,
788
+ params : Dict [str , Any ],
789
+ args : List [Any ],
790
+ kw_args : Dict [str , Any ],
791
+ ) -> None :
792
+ """Add a parameter from the original params dict."""
793
+ value = params [param .name ]
794
+ if param .kind == inspect .Parameter .POSITIONAL_ONLY :
795
+ args .append (value )
796
+ else :
797
+ kw_args [param .name ] = value
798
+
799
+ def _handle_var_keywords (
800
+ self ,
801
+ rest : Set [str ],
802
+ converted_params : Any ,
803
+ params : Any ,
804
+ kw_args : Dict [str , Any ],
805
+ params_added : bool ,
806
+ ) -> None :
807
+ """Handle remaining parameters for functions with **kwargs."""
808
+ for name in rest :
809
+ if hasattr (converted_params , name ):
810
+ kw_args [name ] = getattr (converted_params , name )
811
+ elif isinstance (params , dict ) and name in params :
812
+ kw_args [name ] = params [name ]
813
+
814
+ if not params_added :
815
+ kw_args ["params" ] = converted_params
816
+
728
817
async def handle_request (self , message : JsonRPCRequest ) -> None :
729
818
try :
730
819
e = self .registry .get_entry (message .method )
@@ -859,7 +948,14 @@ async def handle_notification(self, message: JsonRPCNotification) -> None:
859
948
pass
860
949
except (SystemExit , KeyboardInterrupt ):
861
950
raise
951
+ except JsonRPCErrorException :
952
+ # Specific RPC errors should be re-raised
953
+ raise
954
+ except (ValueError , TypeError ) as e :
955
+ # Parameter validation errors
956
+ self .__logger .warning (lambda : f"Parameter validation failed for { message .method } : { e } " , exc_info = e )
862
957
except BaseException as e :
958
+ # Unexpected errors
863
959
self .__logger .exception (e )
864
960
865
961
0 commit comments