
Commit 65d9cc0

Add openai frequency and presence penalty parameters. Closes abetlen#169

1 parent: 75d8619

2 files changed: +36, -6 lines

llama_cpp/llama.py

Lines changed: 36 additions & 2 deletions
@@ -261,14 +261,16 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             ]
         self.eval_logits.extend(logits)
 
-    def _sample_top_p_top_k(
+    def _sample(
         self,
         last_n_tokens_data, # type: llama_cpp.Array[llama_cpp.llama_token]
         last_n_tokens_size: llama_cpp.c_int,
         top_k: llama_cpp.c_int,
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
+        frequency_penalty: llama_cpp.c_float,
+        presence_penalty: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -298,6 +300,14 @@ def _sample_top_p_top_k(
             candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
             penalty=repeat_penalty,
         )
+        llama_cpp.llama_sample_frequency_and_presence_penalties(
+            ctx=self.ctx,
+            candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            alpha_frequency=frequency_penalty,
+            alpha_presence=presence_penalty,
+        )
         if float(temp.value) == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
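For reference, llama_sample_frequency_and_presence_penalties applies the OpenAI-style penalty rule: each candidate logit is lowered by alpha_frequency times the number of times that token occurs in the recent-token window, plus a flat alpha_presence if it occurs at all. Below is a minimal pure-Python sketch of that arithmetic; the function name is made up for illustration, and the real computation happens inside the llama.cpp call shown above.

from collections import Counter
from typing import Dict, List

def penalize_logits(
    logits: List[float],
    last_tokens: List[int],
    alpha_frequency: float,
    alpha_presence: float,
) -> List[float]:
    # OpenAI rule: logits[t] -= count(t) * alpha_frequency
    #                           + (count(t) > 0) * alpha_presence
    counts: Dict[int, int] = Counter(last_tokens)
    out = list(logits)
    for tok, count in counts.items():
        # count is always > 0 here, so the presence term applies unconditionally
        out[tok] -= count * alpha_frequency + alpha_presence
    return out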
@@ -344,6 +354,8 @@ def sample(
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
     ):
         """Sample a token from the model.
 
@@ -360,7 +372,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return self._sample_top_p_top_k(
+        return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -369,6 +381,8 @@ def sample(
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            frequency_penalty=llama_cpp.c_float(frequency_penalty),
+            presence_penalty=llama_cpp.c_float(presence_penalty),
         )
 
     def generate(
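Because the new keyword arguments default to 0.0, existing callers of sample keep their behavior; opting in is just a matter of passing non-zero values. A usage sketch (the model path is a placeholder and the prompt handling is condensed):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
llm.eval(llm.tokenize(b"Q: Name the planets in the solar system. A:"))
token = llm.sample(
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
    frequency_penalty=0.5,  # new: grows with each repeat of a token
    presence_penalty=0.5,   # new: flat cost once a token has appeared
)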
@@ -378,6 +392,8 @@ def generate(
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
@@ -431,6 +447,8 @@ def generate(
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
                 repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
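generate simply forwards both penalties to sample on every iteration, so streaming loops pick them up the same way. A sketch, reusing the llm handle from the previous example:

import llama_cpp

tokens = llm.tokenize(b"Once upon a time")
for token in llm.generate(
    tokens,
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
    frequency_penalty=0.5,
    presence_penalty=0.5,
):
    if token == llama_cpp.llama_token_eos():
        break
    # detokenize returns bytes; replace errors in case of partial UTF-8
    print(llm.detokenize([token]).decode("utf-8", errors="replace"), end="")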
@@ -505,6 +523,8 @@ def _create_completion(
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -563,6 +583,8 @@ def _create_completion(
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
@@ -737,6 +759,8 @@ def create_completion(
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -772,6 +796,8 @@ def create_completion(
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -792,6 +818,8 @@ def __call__(
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -827,6 +855,8 @@ def __call__(
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
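At the top of the stack the parameters now line up with the OpenAI completion API: create_completion and __call__ both accept them and default to 0.0, meaning no penalty. A sketch of the high-level call, continuing with the llm handle from above:

output = llm(
    "Q: Suggest three uses for a paperclip. A:",
    max_tokens=64,
    frequency_penalty=0.7,  # discourage verbatim repetition
    presence_penalty=0.3,   # nudge sampling toward unseen tokens
)
print(output["choices"][0]["text"])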
@@ -899,6 +929,8 @@ def create_chat_completion(
         stream: bool = False,
         stop: Optional[List[str]] = [],
         max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@ def create_chat_completion(
             stream=stream,
             max_tokens=max_tokens,
             repeat_penalty=repeat_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
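Note that create_chat_completion declares presence_penalty before frequency_penalty, the reverse of the other signatures; callers should pass both by keyword, as in this sketch:

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Suggest five names for a pet llama."}],
    presence_penalty=0.6,
    frequency_penalty=0.6,
)
print(response["choices"][0]["message"]["content"])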

llama_cpp/server/app.py

Lines changed: 0 additions & 4 deletions
@@ -214,8 +214,6 @@ def create_completion(
            exclude={
                "model",
                "n",
-                "frequency_penalty",
-                "presence_penalty",
                "best_of",
                "logit_bias",
                "user",
@@ -315,8 +313,6 @@ def create_chat_completion(
            exclude={
                "model",
                "n",
-                "presence_penalty",
-                "frequency_penalty",
                "logit_bias",
                "user",
            }
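The server change is pure deletion: frequency_penalty and presence_penalty used to be stripped from incoming requests via these exclude sets, so they never reached Llama; removing them lets the OpenAI-compatible endpoints forward both fields. A sketch of exercising that over HTTP, assuming the server is running at its default address (host, port, and model are deployment-specific):

import requests  # third-party HTTP client, used here for brevity

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed default host/port
    json={
        "prompt": "Q: What is the capital of France? A:",
        "max_tokens": 16,
        "frequency_penalty": 0.5,  # now forwarded instead of dropped
        "presence_penalty": 0.5,
    },
)
print(resp.json()["choices"][0]["text"])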
