diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index ea8d07feb..af4a078d6 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -582,7 +582,7 @@ def chat_completion_handler(
         if result.stopping_criteria is not None:
             stopping_criteria = result.stopping_criteria
 
-        if response_format is not None and response_format["type"] == "json_object":
+        if response_format is not None:
             grammar = _grammar_for_response_format(
                 response_format, verbose=llama.verbose
             )
@@ -928,6 +928,13 @@ def _grammar_for_response_format(
     response_format: llama_types.ChatCompletionRequestResponseFormat,
     verbose: bool = False,
 ):
+
+    # convert openai type 'json_schema' to llama_cpp type 'json_object':
+    if response_format['type'] == "json_schema":
+        response_format['type'] = "json_object"
+        response_format['schema'] = response_format['json_schema']['schema']
+        del response_format['json_schema']
+
     if response_format["type"] != "json_object":
         return None
 
@@ -2830,7 +2837,7 @@ def embed_image_bytes(image_bytes: bytes):
         # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
 
-        if response_format is not None and response_format["type"] == "json_object":
+        if response_format is not None:
             grammar = _grammar_for_response_format(response_format)
 
         # Convert legacy functions to tools
@@ -3442,7 +3449,7 @@ def chatml_function_calling(
         add_generation_prompt=True,
     )
 
-    if response_format is not None and response_format["type"] == "json_object":
+    if response_format is not None:
         grammar = _grammar_for_response_format(response_format)
 
     return _convert_completion_to_chat(
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index bbb58afc3..eb6d593e6 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -156,10 +156,13 @@ class ChatCompletionFunctionCallOption(TypedDict):
 
 
 class ChatCompletionRequestResponseFormat(TypedDict):
-    type: Literal["text", "json_object"]
+    type: Literal["text", "json_object", "json_schema"]
     schema: NotRequired[
         JsonType
     ]  # https://docs.endpoints.anyscale.com/guides/json_mode/
+    json_schema: NotRequired[
+        JsonType
+    ]
 
 
 class ChatCompletionRequestMessageContentPartText(TypedDict):
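
A minimal usage sketch (not part of the patch) of the OpenAI-style "json_schema" response_format that this change accepts; the model path and schema below are illustrative only, and the handler is expected to rewrite the request into the internal "json_object" + "schema" form before building the grammar.

    from llama_cpp import Llama

    # Hypothetical local model path, for illustration only.
    llm = Llama(model_path="./models/model.Q4_K_M.gguf")

    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "Extract the user info as JSON."},
            {"role": "user", "content": "I'm Ada, 36 years old."},
        ],
        # OpenAI-style response_format; _grammar_for_response_format converts
        # 'json_schema' to 'json_object' and pulls out the nested schema.
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "user_info",
                "schema": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "integer"},
                    },
                    "required": ["name", "age"],
                },
            },
        },
    )
    print(response["choices"][0]["message"]["content"])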