|
125 | 125 | "tf": (),
|
126 | 126 | "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
|
127 | 127 | "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
|
| 128 | + "type": "audio", |
128 | 129 | },
|
129 | 130 | "automatic-speech-recognition": {
|
130 | 131 | "impl": AutomaticSpeechRecognitionPipeline,
|
131 | 132 | "tf": (),
|
132 | 133 | "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
|
133 | 134 | "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
|
| 135 | + "type": "multimodal", |
134 | 136 | },
|
135 | 137 | "feature-extraction": {
|
136 | 138 | "impl": FeatureExtractionPipeline,
|
137 | 139 | "tf": (TFAutoModel,) if is_tf_available() else (),
|
138 | 140 | "pt": (AutoModel,) if is_torch_available() else (),
|
139 | 141 | "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
| 142 | + "type": "multimodal", |
140 | 143 | },
|
141 | 144 | "text-classification": {
|
142 | 145 | "impl": TextClassificationPipeline,
|
|
148 | 151 | "tf": "distilbert-base-uncased-finetuned-sst-2-english",
|
149 | 152 | },
|
150 | 153 | },
|
| 154 | + "type": "text", |
151 | 155 | },
|
152 | 156 | "token-classification": {
|
153 | 157 | "impl": TokenClassificationPipeline,
|
|
159 | 163 | "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
160 | 164 | },
|
161 | 165 | },
|
| 166 | + "type": "text", |
162 | 167 | },
|
163 | 168 | "question-answering": {
|
164 | 169 | "impl": QuestionAnsweringPipeline,
|
|
167 | 172 | "default": {
|
168 | 173 | "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
169 | 174 | },
|
| 175 | + "type": "text", |
170 | 176 | },
|
171 | 177 | "table-question-answering": {
|
172 | 178 | "impl": TableQuestionAnsweringPipeline,
|
|
179 | 185 | "tf": "google/tapas-base-finetuned-wtq",
|
180 | 186 | },
|
181 | 187 | },
|
| 188 | + "type": "text", |
182 | 189 | },
|
183 | 190 | "fill-mask": {
|
184 | 191 | "impl": FillMaskPipeline,
|
185 | 192 | "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
|
186 | 193 | "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
|
187 | 194 | "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
| 195 | + "type": "text", |
188 | 196 | },
|
189 | 197 | "summarization": {
|
190 | 198 | "impl": SummarizationPipeline,
|
191 | 199 | "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
192 | 200 | "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
193 | 201 | "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
|
| 202 | + "type": "text", |
194 | 203 | },
|
195 | 204 | # This task is a special case as it's parametrized by SRC, TGT languages.
|
196 | 205 | "translation": {
|
|
202 | 211 | ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
203 | 212 | ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
204 | 213 | },
|
| 214 | + "type": "text", |
205 | 215 | },
|
206 | 216 | "text2text-generation": {
|
207 | 217 | "impl": Text2TextGenerationPipeline,
|
208 | 218 | "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
|
209 | 219 | "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
|
210 | 220 | "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
| 221 | + "type": "text", |
211 | 222 | },
|
212 | 223 | "text-generation": {
|
213 | 224 | "impl": TextGenerationPipeline,
|
214 | 225 | "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
|
215 | 226 | "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
|
216 | 227 | "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
| 228 | + "type": "text", |
217 | 229 | },
|
218 | 230 | "zero-shot-classification": {
|
219 | 231 | "impl": ZeroShotClassificationPipeline,
|
|
224 | 236 | "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
225 | 237 | "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
|
226 | 238 | },
|
| 239 | + "type": "text", |
227 | 240 | },
|
228 | 241 | "conversational": {
|
229 | 242 | "impl": ConversationalPipeline,
|
230 | 243 | "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
|
231 | 244 | "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
|
232 | 245 | "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
|
| 246 | + "type": "text", |
233 | 247 | },
|
234 | 248 | "image-classification": {
|
235 | 249 | "impl": ImageClassificationPipeline,
|
236 | 250 | "tf": (),
|
237 | 251 | "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
|
238 | 252 | "default": {"model": {"pt": "google/vit-base-patch16-224"}},
|
| 253 | + "type": "image", |
239 | 254 | },
|
240 | 255 | "image-segmentation": {
|
241 | 256 | "impl": ImageSegmentationPipeline,
|
242 | 257 | "tf": (),
|
243 | 258 | "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
|
244 | 259 | "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
|
| 260 | + "type": "image", |
245 | 261 | },
|
246 | 262 | "object-detection": {
|
247 | 263 | "impl": ObjectDetectionPipeline,
|
248 | 264 | "tf": (),
|
249 | 265 | "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
|
250 | 266 | "default": {"model": {"pt": "facebook/detr-resnet-50"}},
|
| 267 | + "type": "image", |
251 | 268 | },
|
252 | 269 | }
|
253 | 270 |
|
# Partition the supported tasks by modality so pipeline construction knows
# which preprocessor kinds to skip: pure-text tasks never need a feature
# extractor, and audio/image tasks never need a tokenizer.  "multimodal"
# tasks may need both, so they land in neither set.  Any other type value
# is a registration error and fails loudly at import time.
NO_FEATURE_EXTRACTOR_TASKS = set()
NO_TOKENIZER_TASKS = set()
for task, task_info in SUPPORTED_TASKS.items():
    task_type = task_info["type"]
    if task_type == "text":
        NO_FEATURE_EXTRACTOR_TASKS.add(task)
    elif task_type in {"audio", "image"}:
        NO_TOKENIZER_TASKS.add(task)
    elif task_type != "multimodal":
        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {task_type}")
| 280 | + |
254 | 281 |
|
255 | 282 | def get_supported_tasks() -> List[str]:
|
256 | 283 | """
|
@@ -528,12 +555,14 @@ def pipeline(
|
528 | 555 | load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
529 | 556 | load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
530 | 557 |
|
531 |
| - if task in {"audio-classification", "image-classification"}: |
| 558 | + if task in NO_TOKENIZER_TASKS: |
532 | 559 | # These will never require a tokenizer.
|
533 | 560 | # the model on the other hand might have a tokenizer, but
|
534 | 561 | # the files could be missing from the hub, instead of failing
|
535 | 562 | # on such repos, we just force to not load it.
|
536 | 563 | load_tokenizer = False
|
| 564 | + if task in NO_FEATURE_EXTRACTOR_TASKS: |
| 565 | + load_feature_extractor = False |
537 | 566 |
|
538 | 567 | if load_tokenizer:
|
539 | 568 | # Try to infer tokenizer from model or config name (if provided as str)
|
|
0 commit comments