Skip to content

Commit 4eb88f8

Browse files
committed
Put back the GGML pipeline and removed the GPTQ pipeline; the earlier commit had them backwards.
1 parent 9a3ca91 commit 4eb88f8

File tree

1 file changed

+11
-28
lines changed

1 file changed

+11
-28
lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -127,44 +127,27 @@ def __next__(self):
127127
if value != self.stop_signal:
128128
return value
129129

130-
131-
class GGMLPipeline(object):
    """Text-generation pipeline backed by a GGML-quantized model via ctransformers.

    Exposes the minimal interface the surrounding module expects of a
    pipeline wrapper: ``__call__`` for batch generation and ``stream`` for
    incremental token output. No Hugging Face tokenizer is involved —
    ctransformers tokenizes internally — so ``self.tokenizer`` stays ``None``.
    """

    def __init__(self, model_name, **task):
        # Imported lazily so ctransformers is only required when a GGML
        # model is actually requested.
        import ctransformers

        # Strip caller-level routing keys that ctransformers' constructor
        # does not accept. Pop with a default so a caller that omits one
        # of them does not raise KeyError.
        task.pop("model", None)
        task.pop("task", None)
        task.pop("device", None)
        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(
            model_name, **task
        )
        self.tokenizer = None  # ctransformers handles tokenization itself
        self.task = "text-generation"

    def stream(self, inputs, **kwargs):
        """Stream generated tokens for the first prompt in ``inputs``.

        NOTE(review): only ``inputs[0]`` is streamed — presumably the
        streaming path is single-prompt; confirm against callers.
        """
        output = self.model(inputs[0], stream=True, **kwargs)
        return ThreadedGeneratorIterator(output, inputs[0])

    def __call__(self, inputs, **kwargs):
        """Generate one completion per prompt; returns a list of strings."""
        # Comprehension replaces the manual append loop; `prompt` avoids
        # shadowing the `input` builtin.
        return [self.model(prompt, **kwargs) for prompt in inputs]

0 commit comments

Comments
 (0)