Skip to content

Commit 95e1e9a

Browse files
committed
Updated to work with hugging face tokens
1 parent e5eccec commit 95e1e9a

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ def __init__(self, model_name, **kwargs):
200200
# but that is only possible when the task is passed in, since if you pass the model
201201
# to the pipeline constructor, the task will no longer be inferred from the default...
202202

203+
# See: https://huggingface.co/docs/hub/security-tokens
204+
# This renaming is for backwards compatibility
205+
if "use_auth_token" in kwargs:
206+
kwargs["token"] = kwargs.pop("use_auth_token")
207+
203208
if (
204209
"task" in kwargs
205210
and model_name is not None
@@ -230,9 +235,9 @@ def __init__(self, model_name, **kwargs):
230235
else:
231236
raise PgMLException(f"Unhandled task: {self.task}")
232237

233-
if "use_auth_token" in kwargs:
238+
if "token" in kwargs:
234239
self.tokenizer = AutoTokenizer.from_pretrained(
235-
model_name, use_auth_token=kwargs["use_auth_token"]
240+
model_name, use_auth_token=kwargs["token"]
236241
)
237242
else:
238243
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -302,18 +307,6 @@ def __call__(self, inputs, **kwargs):
302307
outputs = outputs[:, inputs["input_ids"].shape[1] :]
303308
outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
304309
return outputs
305-
306-
# I don't think conversations support num_responses and/or maybe num_beams
307-
# Also this is not processed in parallel / truly batched it seems
308-
# num_conversations = 1
309-
# if "num_return_sequences" in kwargs:
310-
# num_conversations = kwargs.pop("num_return_sequences")
311-
# conversations = [Conversation(inputs) for _ in range(0, num_conversations)]
312-
# conversations = self.pipe(conversations, **kwargs)
313-
# outputs = []
314-
# for conversation in conversations:
315-
# outputs.append(conversation.messages[-1]["content"])
316-
# return outputs
317310
else:
318311
return self.pipe(inputs, **kwargs)
319312

0 commit comments

Comments (0)