diff --git a/tests/refresh_script.py b/tests/refresh_script.py new file mode 100644 index 0000000000..605c6bde33 --- /dev/null +++ b/tests/refresh_script.py @@ -0,0 +1,23 @@ +import sys +import time + +from tuf.ngclient import Updater + +print(f"Fetching metadata {sys.argv[1]} times:") +print(f" metadata dir: {sys.argv[2]}") +print(f" metadata url: {sys.argv[3]}") + +start = time.time() + +for i in range(int(sys.argv[1])): + try: + refresh_start = time.time() + u = Updater(metadata_dir=sys.argv[2], metadata_base_url=sys.argv[3]) + # file3.txt is delegated so we end up exercising all metadata load paths + u.get_targetinfo("file3.txt") + except OSError as e: + print( + f"Failed on iteration {i}, " + f"{time.time() - refresh_start} secs elapsed ({time.time() - start} total)" + ) + raise e diff --git a/tests/test_updater_delegation_graphs.py b/tests/test_updater_delegation_graphs.py index 770a1b3d71..f6a6686283 100644 --- a/tests/test_updater_delegation_graphs.py +++ b/tests/test_updater_delegation_graphs.py @@ -136,7 +136,9 @@ def _assert_files_exist(self, roles: Iterable[str]) -> None: """Assert that local metadata files match 'roles'""" expected_files = [f"{role}.json" for role in roles] found_files = [ - e.name for e in os.scandir(self.metadata_dir) if e.is_file() + e.name + for e in os.scandir(self.metadata_dir) + if e.is_file() and e.name != ".lock" ] self.assertListEqual(sorted(found_files), sorted(expected_files)) diff --git a/tests/test_updater_ng.py b/tests/test_updater_ng.py index 50ef5ee3be..0aae75ac71 100644 --- a/tests/test_updater_ng.py +++ b/tests/test_updater_ng.py @@ -8,6 +8,7 @@ import logging import os import shutil +import subprocess import sys import tempfile import unittest @@ -157,7 +158,9 @@ def _assert_files_exist(self, roles: Iterable[str]) -> None: """Assert that local metadata files match 'roles'""" expected_files = [f"{role}.json" for role in roles] found_files = [ - e.name for e in os.scandir(self.client_directory) if e.is_file() + e.name + for e in os.scandir(self.client_directory) + if e.is_file() and e.name != ".lock" ] self.assertListEqual(sorted(found_files), sorted(expected_files)) @@ -353,6 +356,46 @@ def test_user_agent(self) -> None: self.assertEqual(ua[:23], "MyApp/1.2.3 python-tuf/") + def test_parallel_updaters(self) -> None: + # Refresh many updaters in parallel many times, using the same local metadata cache. + # This should reveal race conditions. + + iterations = 50 + process_count = 10 + + project_root_dir = os.path.dirname(utils.TESTS_DIR) + + command = [ + sys.executable, + "-m", + "tests.refresh_script", + str(iterations), + self.client_directory, + self.metadata_url, + ] + + procs = [ + subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=project_root_dir, + ) + for _ in range(process_count) + ] + + errout = "" + for proc in procs: + stdout, stderr = proc.communicate() + if proc.returncode != 0: + errout += "Parallel Refresh script failed:" + errout += f"\nprocess stdout: \n{stdout.decode()}" + errout += f"\nprocess stderr: \n{stderr.decode()}" + if errout: + self.fail( + f"One or more scripts failed parallel refresh test:\n{errout}" + ) + if __name__ == "__main__": utils.configure_test_logging(sys.argv) diff --git a/tests/test_updater_top_level_update.py b/tests/test_updater_top_level_update.py index 76c74d4b57..c588661220 100644 --- a/tests/test_updater_top_level_update.py +++ b/tests/test_updater_top_level_update.py @@ -94,7 +94,9 @@ def _assert_files_exist(self, roles: Iterable[str]) -> None: """Assert that local metadata files match 'roles'""" expected_files = [f"{role}.json" for role in roles] found_files = [ - e.name for e in os.scandir(self.metadata_dir) if e.is_file() + e.name + for e in os.scandir(self.metadata_dir) + if e.is_file() and e.name != ".lock" ] self.assertListEqual(sorted(found_files), sorted(expected_files)) @@ -133,8 +135,7 @@ def test_cached_root_missing_without_bootstrap(self) -> None: self._run_refresh(skip_bootstrap=True) # Metadata dir is empty - with self.assertRaises(FileNotFoundError): - os.listdir(self.metadata_dir) + self._assert_files_exist([]) def test_trusted_root_expired(self) -> None: # Create an expired root version @@ -644,14 +645,16 @@ def test_not_loading_targets_twice(self, wrapped_open: MagicMock) -> None: wrapped_open.reset_mock() # First time looking for "somepath", only 'role1' must be loaded + # (and ".lock" for metadata locking) updater.get_targetinfo("somepath") - wrapped_open.assert_called_once_with( + self.assertEqual(wrapped_open.call_count, 2) + wrapped_open.assert_called_with( os.path.join(self.metadata_dir, "role1.json"), "rb" ) wrapped_open.reset_mock() # Second call to get_targetinfo, all metadata is already loaded updater.get_targetinfo("somepath") - wrapped_open.assert_not_called() + self.assertEqual(wrapped_open.call_count, 1) def test_snapshot_rollback_with_local_snapshot_hash_mismatch(self) -> None: # Test triggering snapshot rollback check on a newly downloaded snapshot @@ -709,6 +712,7 @@ def test_load_metadata_from_cache(self, wrapped_open: MagicMock) -> None: root_dir = os.path.join(self.metadata_dir, "root_history") wrapped_open.assert_has_calls( [ + call(os.path.join(self.metadata_dir, ".lock"), "wb"), call(os.path.join(root_dir, "2.root.json"), "rb"), call(os.path.join(self.metadata_dir, "timestamp.json"), "rb"), call(os.path.join(self.metadata_dir, "snapshot.json"), "rb"), diff --git a/tests/test_updater_validation.py b/tests/test_updater_validation.py index b9d6bb3cc7..8020e69de9 100644 --- a/tests/test_updater_validation.py +++ b/tests/test_updater_validation.py @@ -53,9 +53,9 @@ def test_local_target_storage_fail(self) -> None: def test_non_existing_metadata_dir(self) -> None: with self.assertRaises(FileNotFoundError): - # Initialize Updater with non-existing metadata_dir + # Initialize Updater with non-existing metadata_dir and no bootstrap root Updater( - "non_existing_metadata_dir", + f"{self.temp_dir.name}/non_existing_metadata_dir", "https://example.com/metadata/", fetcher=self.sim, ) diff --git a/tests/utils.py b/tests/utils.py index bbfb07dbaa..727abadb3e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -161,7 +161,7 @@ def cleanup_metadata_dir(path: str) -> None: for entry in it: if entry.name == "root_history": cleanup_metadata_dir(entry.path) - elif entry.name.endswith(".json"): + elif entry.name.endswith(".json") or entry.name == ".lock": os.remove(entry.path) else: raise ValueError(f"Unexpected local metadata file {entry.path}") diff --git a/tox.ini b/tox.ini index 7ef098ba3c..edb657172a 100644 --- a/tox.ini +++ b/tox.ini @@ -11,8 +11,8 @@ skipsdist = true [testenv] commands = python3 --version - python3 -m coverage run -m unittest - python3 -m coverage report -m --fail-under 97 + python3 -m coverage run -m unittest -v + python3 -m coverage report -m --fail-under 96 deps = -r{toxinidir}/requirements/test.txt diff --git a/tuf/ngclient/_internal/file_lock.py b/tuf/ngclient/_internal/file_lock.py new file mode 100644 index 0000000000..d970c17531 --- /dev/null +++ b/tuf/ngclient/_internal/file_lock.py @@ -0,0 +1,56 @@ +import logging +import time +from collections.abc import Iterator +from contextlib import contextmanager +from typing import IO + +logger = logging.getLogger(__name__) + +try: + # advisory file locking for posix + import fcntl + + @contextmanager + def lock_file(path: str) -> Iterator[IO]: + with open(path, "wb") as f: + fcntl.lockf(f, fcntl.LOCK_EX) + yield f + +except ModuleNotFoundError: + # Windows file locking, in belt-and-suspenders-from-Temu style: + # Use a loop that tries to open the lockfile for 30 secs, but also + # use msvcrt.locking(). + # * since open() usually just fails when another process has the file open + # msvcrt.locking() almost never gets called when there is a lock. open() + # sometimes succeeds for multiple processes though + # * msvcrt.locking() does not even block until file is available: it just + # tries once per second in a non-blocking manner for 10 seconds. So if + # another process keeps opening the file it's unlikely that we actually + # get the lock + import msvcrt + + @contextmanager + def lock_file(path: str) -> Iterator[IO]: + err = None + locked = False + for _ in range(100): + try: + with open(path, "wb") as f: + msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1) + locked = True + yield f + return + except FileNotFoundError: + # could be from yield or from open() -- either way we bail + raise + except OSError as e: + if locked: + # yield has raised, let's not continue loop + raise e + err = e + logger.warning("Unsuccessful lock attempt for %s: %s", path, e) + time.sleep(0.3) + + # raise the last failure if we never got a lock + if err is not None: + raise err diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py index a98e799ce4..e96192a16d 100644 --- a/tuf/ngclient/updater.py +++ b/tuf/ngclient/updater.py @@ -29,10 +29,6 @@ * ``Updater.download_target()`` downloads a target file and ensures it is verified correct by the metadata. -Note that applications using ``Updater`` should be 'single instance' -applications: running multiple instances that use the same cache directories at -the same time is not supported. - A simple example of using the Updater to implement a Python TUF client that downloads target files is available in `examples/client `_. @@ -64,11 +60,14 @@ from tuf.api import exceptions from tuf.api.metadata import Root, Snapshot, TargetFile, Targets, Timestamp +from tuf.ngclient._internal.file_lock import lock_file from tuf.ngclient._internal.trusted_metadata_set import TrustedMetadataSet from tuf.ngclient.config import EnvelopeType, UpdaterConfig from tuf.ngclient.urllib3_fetcher import Urllib3Fetcher if TYPE_CHECKING: + from collections.abc import Iterator + from tuf.ngclient.fetcher import FetcherInterface logger = logging.getLogger(__name__) @@ -131,16 +130,30 @@ def __init__( f"got '{self.config.envelope_type}'" ) - if not bootstrap: - # if no root was provided, use the cached non-versioned root.json - bootstrap = self._load_local_metadata(Root.type) + # Ensure the whole metadata directory structure exists + rootdir = Path(self._dir, "root_history") + rootdir.mkdir(exist_ok=True, parents=True) - # Load the initial root, make sure it's cached - self._trusted_set = TrustedMetadataSet( - bootstrap, self.config.envelope_type - ) - self._persist_root(self._trusted_set.root.version, bootstrap) - self._update_root_symlink() + with self._lock_metadata(): + if not bootstrap: + # if no root was provided, use the cached non-versioned root + bootstrap = self._load_local_metadata(Root.type) + + # Load the initial root, make sure it's cached + self._trusted_set = TrustedMetadataSet( + bootstrap, self.config.envelope_type + ) + self._persist_root(self._trusted_set.root.version, bootstrap) + self._update_root_symlink() + + @contextlib.contextmanager + def _lock_metadata(self) -> Iterator[None]: + """Context manager for locking the metadata directory.""" + + logger.debug("Getting metadata lock...") + with lock_file(os.path.join(self._dir, ".lock")): + yield + logger.debug("Released metadata lock") def refresh(self) -> None: """Refresh top-level metadata. @@ -166,10 +179,11 @@ def refresh(self) -> None: DownloadError: Download of a metadata file failed in some way """ - self._load_root() - self._load_timestamp() - self._load_snapshot() - self._load_targets(Targets.type, Root.type) + with self._lock_metadata(): + self._load_root() + self._load_timestamp() + self._load_snapshot() + self._load_targets(Targets.type, Root.type) def _generate_target_file_path(self, targetinfo: TargetFile) -> str: if self.target_dir is None: @@ -205,9 +219,14 @@ def get_targetinfo(self, target_path: str) -> TargetFile | None: ``TargetFile`` instance or ``None``. """ - if Targets.type not in self._trusted_set: - self.refresh() - return self._preorder_depth_first_walk(target_path) + with self._lock_metadata(): + if Targets.type not in self._trusted_set: + # implicit refresh + self._load_root() + self._load_timestamp() + self._load_snapshot() + self._load_targets(Targets.type, Root.type) + return self._preorder_depth_first_walk(target_path) def find_cached_target( self, @@ -295,7 +314,7 @@ def download_target( targetinfo.verify_length_and_hashes(target_file) target_file.seek(0) - with open(filepath, "wb") as destination_file: + with lock_file(filepath) as destination_file: shutil.copyfileobj(target_file, destination_file) logger.debug("Downloaded target %s", targetinfo.path) @@ -335,7 +354,6 @@ def _persist_root(self, version: int, data: bytes) -> None: "root_history/1.root.json"). """ rootdir = Path(self._dir, "root_history") - rootdir.mkdir(exist_ok=True, parents=True) self._persist_file(str(rootdir / f"{version}.root.json"), data) def _persist_file(self, filename: str, data: bytes) -> None: