diff --git a/localstack-core/localstack/services/transcribe/provider.py b/localstack-core/localstack/services/transcribe/provider.py
index b24f960d8e867..e26705386abb1 100644
--- a/localstack-core/localstack/services/transcribe/provider.py
+++ b/localstack-core/localstack/services/transcribe/provider.py
@@ -226,7 +226,7 @@ def download_model(name: str):
         model_path = LANGUAGE_MODEL_DIR / name
 
         with _DL_LOCK:
-            if (model_path).exists():
+            if model_path.exists():
                 return
             else:
                 model_path.mkdir(parents=True)
diff --git a/tests/aws/conftest.py b/tests/aws/conftest.py
index 7d9669340fc0d..d8a9ace182578 100644
--- a/tests/aws/conftest.py
+++ b/tests/aws/conftest.py
@@ -45,11 +45,19 @@ def pytest_runtestloop(session):
             )
 
             test_init_functions.add(opensearch_install_async)
-        if any(opensearch_test in parent_name for opensearch_test in ["test_es", "firehose"]):
+
+        if any(es_test in parent_name for es_test in ["elasticsearch", "firehose"]):
             from tests.aws.services.es.test_es import install_async as es_install_async
 
             test_init_functions.add(es_install_async)
 
+        if "transcribe" in parent_name:
+            from tests.aws.services.transcribe.test_transcribe import (
+                install_async as transcribe_install_async,
+            )
+
+            test_init_functions.add(transcribe_install_async)
+
     # add init functions for certain tests that download/install things
     for test_class in test_classes:
         # set flag that terraform will be used
diff --git a/tests/aws/services/transcribe/test_transcribe.py b/tests/aws/services/transcribe/test_transcribe.py
index 1ab7b1b16bf4f..9ba22484369f7 100644
--- a/tests/aws/services/transcribe/test_transcribe.py
+++ b/tests/aws/services/transcribe/test_transcribe.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import threading
+import time
 from urllib.parse import urlparse
 
 import pytest
@@ -7,15 +9,87 @@
 
 from localstack.aws.api.transcribe import BadRequestException, ConflictException, NotFoundException
 from localstack.aws.connect import ServiceLevelClientFactory
+from localstack.packages.ffmpeg import ffmpeg_package
+from localstack.services.transcribe.packages import vosk_package
+from localstack.services.transcribe.provider import LANGUAGE_MODELS, TranscribeProvider
 from localstack.testing.pytest import markers
 from localstack.utils.files import new_tmp_file
 from localstack.utils.strings import short_uid, to_str
 from localstack.utils.sync import poll_condition, retry
+from localstack.utils.threads import start_worker_thread
 
 BASEDIR = os.path.abspath(os.path.dirname(__file__))
 
 LOG = logging.getLogger(__name__)
 
+# Events to ensure that the installation is executed before the tests
+vosk_installed = threading.Event()
+ffmpeg_installed = threading.Event()
+installation_errored = threading.Event()
+
+INSTALLATION_TIMEOUT = 5 * 60
+PRE_DOWNLOAD_LANGUAGE_CODE_MODELS = ["en-GB"]
+
+
+def install_async():
+    """
+    Installs the default ffmpeg and vosk versions in worker threads.
+
+    Each worker sets its own event once it is done, whether it succeeded or not. On failure it
+    also sets ``installation_errored`` and the other worker's event so that waiters stop early.
+
+    NOTE: only finished installations are tracked, so calling this again while an installation
+    is still running starts additional worker threads.
+    """
+    if vosk_installed.is_set() and ffmpeg_installed.is_set():
+        return
+
+    def install_vosk(*args):
+        if vosk_installed.is_set():
+            return
+        try:
+            LOG.info("installing Vosk default version")
+            vosk_package.install()
+            LOG.info("done installing Vosk default version")
+            LOG.info("downloading Vosk models used in test: %s", PRE_DOWNLOAD_LANGUAGE_CODE_MODELS)
+            for language_code in PRE_DOWNLOAD_LANGUAGE_CODE_MODELS:
+                model_name = LANGUAGE_MODELS[language_code]
+                # downloading the model takes quite a while sometimes
+                TranscribeProvider.download_model(model_name)
+                LOG.info(
+                    "done downloading Vosk model '%s' for language code '%s'",
+                    model_name,
+                    language_code,
+                )
+            LOG.info("done downloading all Vosk models used in test")
+        except Exception:
+            LOG.exception("Error during installation of Vosk dependencies")
+            installation_errored.set()
+            # we also set the other event to quickly stop the polling
+            ffmpeg_installed.set()
+        finally:
+            vosk_installed.set()
+
+    def install_ffmpeg(*args):
+        if ffmpeg_installed.is_set():
+            return
+        try:
+            LOG.info("installing ffmpeg default version")
+            ffmpeg_package.install()
+            LOG.info("done installing ffmpeg default version")
+        except Exception:
+            LOG.exception("Error during installation of ffmpeg dependencies")
+            installation_errored.set()
+            # we also set the other event to quickly stop the polling
+            vosk_installed.set()
+        finally:
+            ffmpeg_installed.set()
+
+    # we parallelize the installation of the dependencies
+    # TODO: we could maybe use a ThreadPoolExecutor to use Future instead of manually checking
+    start_worker_thread(install_vosk, name="vosk-install-async")
+    start_worker_thread(install_ffmpeg, name="ffmpeg-install-async")
+
 
 @pytest.fixture(autouse=True)
 def transcribe_snapshot_transformer(snapshot):
@@ -23,6 +97,29 @@ def transcribe_snapshot_transformer(snapshot):
 
 
 class TestTranscribe:
+    @pytest.fixture(scope="class", autouse=True)
+    def pre_install_dependencies(self):
+        """
+        Blocks until Vosk and ffmpeg are installed, triggering the installation if needed.
+
+        Fails if waiting exceeds INSTALLATION_TIMEOUT overall or an installation errored.
+        """
+        if not ffmpeg_installed.is_set() or not vosk_installed.is_set():
+            install_async()
+
+        start = int(time.time())
+        assert vosk_installed.wait(
+            timeout=INSTALLATION_TIMEOUT
+        ), "gave up waiting for Vosk to install"
+        elapsed = int(time.time() - start)
+        assert ffmpeg_installed.wait(
+            timeout=INSTALLATION_TIMEOUT - elapsed
+        ), "gave up waiting for ffmpeg to install"
+        LOG.info("Spent %s seconds downloading transcribe dependencies", int(time.time() - start))
+
+        assert not installation_errored.is_set(), "installation of transcribe dependencies failed"
+        yield
+
     @staticmethod
     def _wait_transcription_job(
         transcribe_client: ServiceLevelClientFactory, transcribe_job_name: str