4 changes: 2 additions & 2 deletions llama_cpp/_internals.py
@@ -216,13 +216,13 @@ def metadata(self) -> Dict[str, str]:
for i in range(llama_cpp.llama_model_meta_count(self.model)):
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
if nbytes > buffer_size:
buffer_size = nbytes
buffer_size = nbytes + 1
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
key = buffer.value.decode("utf-8")
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
if nbytes > buffer_size:
buffer_size = nbytes
buffer_size = nbytes + 1
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
value = buffer.value.decode("utf-8")
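This hunk sizes the metadata read buffer one byte larger than the length reported by llama.cpp before retrying, on the assumption that the C call reports the string length without the trailing NUL (snprintf-style). A minimal sketch of the same grow-and-retry pattern, with the hypothetical `fill_fn` standing in for `llama_model_meta_key_by_index` / `llama_model_meta_val_str_by_index`:

```python
import ctypes
from typing import Callable


def read_c_string(fill_fn: Callable[[ctypes.Array, int], int], initial_size: int = 1024) -> str:
    """Call a snprintf-style C function, growing the buffer if it was too small."""
    buffer_size = initial_size
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = fill_fn(buffer, buffer_size)
    if nbytes > buffer_size:
        # The reported length excludes the NUL terminator, so reserve one
        # extra byte; sizing the buffer to nbytes exactly would truncate
        # the last character on the retry.
        buffer_size = nbytes + 1
        buffer = ctypes.create_string_buffer(buffer_size)
        nbytes = fill_fn(buffer, buffer_size)
    return buffer.value.decode("utf-8")
```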
37 changes: 36 additions & 1 deletion llama_cpp/llama.py
@@ -87,7 +87,7 @@ def __init__(
# Backend Params
numa: bool = False,
# Chat Format Params
chat_format: str = "llama-2",
chat_format: Optional[str] = None,
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Misc
verbose: bool = True,
@@ -343,6 +343,41 @@ def __init__(
if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)

if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)

if chat_format is not None:
self.chat_format = chat_format
if self.verbose:
print(f"Guessed chat format: {chat_format}", file=sys.stderr)
else:
template = self.metadata["tokenizer.chat_template"]
try:
eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
except:
eos_token_id = self.token_eos()
try:
bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
except:
bos_token_id = self.token_bos()

eos_token = self.detokenize([eos_token_id]).decode("utf-8")
bos_token = self.detokenize([bos_token_id]).decode("utf-8")

if self.verbose:
print(f"Using chat template: {template}", file=sys.stderr)
print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template,
eos_token=eos_token,
bos_token=bos_token
).to_chat_handler()

if self.chat_format is None and self.chat_handler is None:
self.chat_format = "llama-2"

@property
def ctx(self) -> llama_cpp.llama_context_p:
assert self._ctx.ctx is not None
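With `chat_format` now defaulting to `None`, constructing a `Llama` without an explicit format exercises the new detection path: a recognized `tokenizer.chat_template` maps to a built-in format name, an unrecognized template is wrapped in a `Jinja2ChatFormatter`, and models with no template fall back to `"llama-2"`. A rough usage sketch (the model path is hypothetical):

```python
from llama_cpp import Llama

# Hypothetical GGUF path; any model whose metadata carries a
# tokenizer.chat_template key triggers the auto-detection added above.
llm = Llama(
    model_path="./models/openhermes-2.5-mistral-7b.Q4_K_M.gguf",
    verbose=True,  # prints "Guessed chat format: ..." or the template in use
)

# No chat_format was passed, so whichever handler __init__ selected is used here.
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
)
print(response["choices"][0]["message"]["content"])
```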
30 changes: 28 additions & 2 deletions llama_cpp/llama_chat_format.py
@@ -14,6 +14,20 @@

from ._utils import suppress_stdout_stderr, Singleton

### Common Chat Templates and Special Tokens ###

# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
CHATML_BOS_TOKEN = "<s>"
CHATML_EOS_TOKEN = "<|im_end|>"

# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"


### Chat Completion Handler ###

class LlamaChatCompletionHandler(Protocol):
"""Base Protocol for a llama chat completion handler.
@@ -118,7 +132,6 @@ def decorator(f: LlamaChatCompletionHandler):

### Chat Formatter ###


@dataclasses.dataclass
class ChatFormatterResponse:
"""Dataclass that stores completion parameters for a given chat format and
@@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
return chat_formatter_to_chat_completion_handler(chat_formatter)


def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
if "tokenizer.chat_template" not in metadata:
return None

if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
return "chatml"

if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
return "mistral-instruct"

return None

### Utility functions for formatting chat prompts ###
# TODO: Replace these with jinja2 templates


def _get_system_message(
@@ -929,7 +955,6 @@ def format_openchat(
_prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)


# Chat format for Saiga models, see more details and available models:
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
@register_chat_format("saiga")
@@ -951,6 +976,7 @@ def format_saiga(
_prompt += "<s>bot"
return ChatFormatterResponse(prompt=_prompt.strip())

# Tricky chat formats that require custom chat handlers

@register_chat_completion_handler("functionary")
def functionary_chat_handler(
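A short sketch of how the new `guess_chat_format_from_gguf_metadata` helper behaves: it returns a format name only for an exact template match and `None` otherwise. The metadata dicts below are hand-built stand-ins for what `Llama.metadata` returns:

```python
from llama_cpp import llama_chat_format

# An exact match against a known template maps to a built-in format name.
chatml_metadata = {"tokenizer.chat_template": llama_chat_format.CHATML_CHAT_TEMPLATE}
print(llama_chat_format.guess_chat_format_from_gguf_metadata(chatml_metadata))  # "chatml"

# Anything else (or a missing key) returns None, and Llama.__init__ then falls
# back to building a Jinja2ChatFormatter from the embedded template.
custom_metadata = {
    "tokenizer.chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}"
}
print(llama_chat_format.guess_chat_format_from_gguf_metadata(custom_metadata))  # None
print(llama_chat_format.guess_chat_format_from_gguf_metadata({}))  # None
```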
4 changes: 2 additions & 2 deletions llama_cpp/server/settings.py
@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
description="Enable NUMA support.",
)
# Chat Format Params
chat_format: str = Field(
default="llama-2",
chat_format: Optional[str] = Field(
default=None,
description="Chat format to use.",
)
clip_model_path: Optional[str] = Field(
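The server settings follow the same convention: leaving `chat_format` unset hands the decision to `Llama.__init__`, while setting it still pins a format explicitly. A sketch, assuming `ModelSettings` is constructed directly (model paths are hypothetical):

```python
from llama_cpp.server.settings import ModelSettings

# Default: chat_format stays None, so the format is guessed from GGUF metadata.
settings = ModelSettings(model="./models/mistral-7b-instruct-v0.1.Q4_K_M.gguf")
print(settings.chat_format)  # None

# Explicit value: auto-detection is bypassed and the named format is used.
pinned = ModelSettings(
    model="./models/mistral-7b-instruct-v0.1.Q4_K_M.gguf",
    chat_format="mistral-instruct",
)
print(pinned.chat_format)  # "mistral-instruct"
```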