Skip to content

Commit c17e7cd

Browse files
authored
Add ability to get a list of supported pipeline tasks (huggingface#14732)
1 parent 3d66146 commit c17e7cd

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

src/transformers/commands/run.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from argparse import ArgumentParser
1616

17-
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, PipelineDataFormat, pipeline
17+
from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
1818
from ..utils import logging
1919
from . import BaseTransformersCLICommand
2020

@@ -63,9 +63,7 @@ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
6363
@staticmethod
6464
def register_subcommand(parser: ArgumentParser):
6565
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
66-
run_parser.add_argument(
67-
"--task", choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()), help="Task to run"
68-
)
66+
run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
6967
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
7068
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
7169
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")

src/transformers/commands/serving.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from argparse import ArgumentParser, Namespace
1616
from typing import Any, List, Optional
1717

18-
from ..pipelines import SUPPORTED_TASKS, TASK_ALIASES, Pipeline, pipeline
18+
from ..pipelines import Pipeline, get_supported_tasks, pipeline
1919
from ..utils import logging
2020
from . import BaseTransformersCLICommand
2121

@@ -104,7 +104,7 @@ def register_subcommand(parser: ArgumentParser):
104104
serve_parser.add_argument(
105105
"--task",
106106
type=str,
107-
choices=list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()),
107+
choices=get_supported_tasks(),
108108
help="The task to run the pipeline on",
109109
)
110110
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")

src/transformers/pipelines/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# See the License for the specific language governing permissions and
2121
# limitations under the License.
2222
import warnings
23-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
23+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
2424

2525
from ..configuration_utils import PretrainedConfig
2626
from ..feature_extraction_utils import PreTrainedFeatureExtractor
@@ -252,6 +252,15 @@
252252
}
253253

254254

255+
def get_supported_tasks() -> List[str]:
256+
"""
257+
Returns a list of supported task strings.
258+
"""
259+
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
260+
supported_tasks.sort()
261+
return supported_tasks
262+
263+
255264
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
256265
tmp = io.BytesIO()
257266
headers = {}
@@ -320,9 +329,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
320329
return targeted_task, (tokens[1], tokens[3])
321330
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
322331

323-
raise KeyError(
324-
f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}"
325-
)
332+
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
326333

327334

328335
def pipeline(

0 commit comments

Comments
 (0)