@@ -33,6 +33,7 @@ def __init__(self, config=None):
33
33
self .message = MessageManager (self .config )
34
34
self .current_user = None
35
35
self .override_provider = None
36
+ self .override_preset = None
36
37
self .override_llm = None
37
38
self .plugin_manager = PluginManager (self .config , self , additional_plugins = ADDITIONAL_PLUGINS )
38
39
self .provider_manager = ProviderManager (self .config , self .plugin_manager )
@@ -135,17 +136,19 @@ def set_override_llm(self, preset_name=None):
135
136
customizations = self .expand_functions (customizations )
136
137
success , provider , user_message = self .provider_manager .load_provider (metadata ['provider' ])
137
138
if success :
138
- self .override_provider = provider
139
139
if self .stream and self .should_stream ():
140
140
self .log .debug ("Adding streaming-specific customizations to LLM request" )
141
141
customizations .update (self .streaming_args (interrupt_handler = True ))
142
142
self .override_llm = provider .make_llm (customizations , use_defaults = True )
143
+ self .override_provider = provider
144
+ self .override_preset = preset
143
145
message = f"Set override LLM based on preset { preset_name } "
144
146
self .log .debug (message )
145
147
return True , self .override_llm , message
146
148
return False , None , user_message
147
149
else :
148
150
self .log .debug ("Unsetting override LLM" )
151
+ self .override_preset = None
149
152
self .override_provider = None
150
153
self .override_llm = None
151
154
message = "Unset override LLM"
@@ -329,8 +332,9 @@ def extract_system_message_from_overrides(self, request_overrides):
329
332
return system_message , request_overrides
330
333
331
334
def should_return_on_function_call (self ):
332
- if self .active_preset :
333
- metadata , _customizations = self .active_preset
335
+ preset = self .override_preset or self .active_preset
336
+ if preset :
337
+ metadata , _customizations = preset
334
338
if 'return_on_function_call' in metadata and metadata ['return_on_function_call' ]:
335
339
return True
336
340
return False
@@ -339,8 +343,9 @@ def is_function_response_message(self, message):
339
343
return message ['message_type' ] == 'function_response'
340
344
341
345
def check_return_on_function_response (self , new_messages ):
342
- if self .active_preset :
343
- metadata , _customizations = self .active_preset
346
+ preset = self .override_preset or self .active_preset
347
+ if preset :
348
+ metadata , _customizations = preset
344
349
if 'return_on_function_response' in metadata and metadata ['return_on_function_response' ]:
345
350
# NOTE: In order to allow for multiple function calling and
346
351
# returning on the LAST function response, we need to allow
0 commit comments