@@ -127,44 +127,27 @@ def __next__(self):
127
127
if value != self .stop_signal :
128
128
return value
129
129
130
-
131
- class GPTQPipeline (object ):
130
+ class GGMLPipeline (object ):
132
131
def __init__ (self , model_name , ** task ):
133
- from auto_gptq import AutoGPTQForCausalLM , BaseQuantizeConfig
134
- from huggingface_hub import snapshot_download
135
-
136
- model_path = snapshot_download (model_name )
132
+ import ctransformers
137
133
138
- quantized_config = BaseQuantizeConfig .from_pretrained (model_path )
139
- self .model = AutoGPTQForCausalLM .from_quantized (
140
- model_path , quantized_config = quantized_config , ** task
134
+ task .pop ("model" )
135
+ task .pop ("task" )
136
+ task .pop ("device" )
137
+ self .model = ctransformers .AutoModelForCausalLM .from_pretrained (
138
+ model_name , ** task
141
139
)
142
- if "use_fast_tokenizer" in task :
143
- self .tokenizer = AutoTokenizer .from_pretrained (
144
- model_path , use_fast = task .pop ("use_fast_tokenizer" )
145
- )
146
- else :
147
- self .tokenizer = AutoTokenizer .from_pretrained (model_path )
140
+ self .tokenizer = None
148
141
self .task = "text-generation"
149
142
150
143
def stream (self , inputs , ** kwargs ):
151
- streamer = TextIteratorStreamer (self .tokenizer )
152
- inputs = self .tokenizer (inputs , return_tensors = "pt" ).to (self .model .device )
153
- generation_kwargs = dict (inputs , streamer = streamer , ** kwargs )
154
- thread = Thread (target = self .model .generate , kwargs = generation_kwargs )
155
- thread .start ()
156
- return streamer
144
+ output = self .model (inputs [0 ], stream = True , ** kwargs )
145
+ return ThreadedGeneratorIterator (output , inputs [0 ])
157
146
158
147
def __call__ (self , inputs , ** kwargs ):
159
148
outputs = []
160
149
for input in inputs :
161
- tokens = (
162
- self .tokenizer (input , return_tensors = "pt" )
163
- .to (self .model .device )
164
- .input_ids
165
- )
166
- token_ids = self .model .generate (input_ids = tokens , ** kwargs )[0 ]
167
- outputs .append (self .tokenizer .decode (token_ids ))
150
+ outputs .append (self .model (input , ** kwargs ))
168
151
return outputs
169
152
170
153
0 commit comments