From d5d787a7b4df58060783293a42b1118cac0fac07 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 7 Feb 2025 11:10:05 +0000
Subject: [PATCH 1/2] [Doc] Fix tutorials (#2768)

(cherry picked from commit 75f113ff5c213389ba0f917633419184e7d0c76b)
---
 .github/workflows/docs.yml                      |  5 +++--
 docs/source/conf.py                             |  4 ++--
 torchrl/envs/utils.py                           |  2 ++
 torchrl/trainers/trainers.py                    | 10 +++-------
 tutorials/sphinx-tutorials/coding_ddpg.py       |  6 ++++++
 tutorials/sphinx-tutorials/coding_dqn.py        |  8 ++++++--
 tutorials/sphinx-tutorials/pretrained_models.py |  4 ++--
 7 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 5f99fb12ba6..17bf26fb86b 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: ["3.10"]
+        python_version: ["3.9"]
         cuda_arch_version: ["12.4"]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@
        bash ./miniconda.sh -b -f -p "${conda_dir}"
        eval "$(${conda_dir}/bin/conda shell.bash hook)"
        printf "* Creating a test environment\n"
-       conda create --prefix "${env_dir}" -y python=3.10
+       conda create --prefix "${env_dir}" -y python=3.9
        printf "* Activating\n"
        conda activate "${env_dir}"
 
@@ -107,6 +107,7 @@ jobs:
        cd ..
 
        # 11. Build doc
+       export MAX_IDLE_COUNT=180  # Max 180 secs before killing an unresponsive collector
        cd ./docs
        # timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
        # bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 35f5e5c3882..96b8b193fc8 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -94,10 +94,10 @@
     "filename_pattern": "reference/generated/tutorials/",  # files to parse
     "notebook_images": "reference/generated/tutorials/media/",  # images to parse
     "download_all_examples": True,
-    "abort_on_example_error": False,
-    "only_warn_on_example_error": True,
+    "abort_on_example_error": True,
     "show_memory": True,
     "capture_repr": ("_repr_html_", "__repr__"),  # capture representations
+    "write_computation_times": True,
 }
 
 napoleon_use_ivar = True
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 39b0faa9692..56ac1687138 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -819,6 +819,8 @@ def check_env_specs(
     spec = Composite(shape=env.batch_size, device=env.device)
     td = last_td.select(*spec.keys(True, True), strict=True)
     if not spec.contains(td):
+        for k, v in spec.items(True):
+            assert v.contains(td[k]), f"{k} is not in {v} (val: {td[k]})"
         raise AssertionError(
             f"spec check failed at root for spec {name}={spec} and data {td}."
         )
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index 83bd050ef96..65be247cd33 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -640,7 +640,7 @@ class ReplayBufferTrainer(TrainerHookBase):
         memmap (bool, optional): if ``True``, a memmap tensordict is created.
             Default is ``False``.
         device (device, optional): device where the samples must be placed.
-            Default is ``cpu``.
+            Defaults to ``None``.
         flatten_tensordicts (bool, optional): if ``True``, the tensordicts will
             be flattened (or equivalently masked with the valid mask obtained from
             the collector) before being passed to the replay buffer. Otherwise,
@@ -666,7 +666,7 @@ def __init__(
         replay_buffer: TensorDictReplayBuffer,
         batch_size: Optional[int] = None,
         memmap: bool = False,
-        device: DEVICE_TYPING = "cpu",
+        device: DEVICE_TYPING | None = None,
         flatten_tensordicts: bool = False,
         max_dims: Optional[Sequence[int]] = None,
     ) -> None:
@@ -695,15 +695,11 @@ def extend(self, batch: TensorDictBase) -> TensorDictBase:
                 pads += [0, pad_value]
             batch = pad(batch, pads)
         batch = batch.cpu()
-        if self.memmap:
-            # We can already place the tensords on the device if they're memmap,
-            # as this is a lazy op
-            batch = batch.memmap_().to(self.device)
         self.replay_buffer.extend(batch)
 
     def sample(self, batch: TensorDictBase) -> TensorDictBase:
         sample = self.replay_buffer.sample(batch_size=self.batch_size)
-        return sample.to(self.device, non_blocking=True)
+        return sample.to(self.device) if self.device is not None else sample
 
     def update_priority(self, batch: TensorDictBase) -> None:
         self.replay_buffer.update_tensordict_priority(batch)
diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py
index 70176f9de4a..734fed2d74f 100644
--- a/tutorials/sphinx-tutorials/coding_ddpg.py
+++ b/tutorials/sphinx-tutorials/coding_ddpg.py
@@ -1185,6 +1185,12 @@ def ceil_div(x, y):
 collector.shutdown()
 del collector
 
+try:
+    parallel_env.close()
+    del parallel_env
+except Exception:
+    pass
+
 ###############################################################################
 # Experiment results
 # ------------------
diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py
index a10e8c1169a..b0e244e4143 100644
--- a/tutorials/sphinx-tutorials/coding_dqn.py
+++ b/tutorials/sphinx-tutorials/coding_dqn.py
@@ -380,11 +380,12 @@ def make_model(dummy_env):
 # time must always have the same shape.
 
 
-def get_replay_buffer(buffer_size, n_optim, batch_size):
+def get_replay_buffer(buffer_size, n_optim, batch_size, device):
     replay_buffer = TensorDictReplayBuffer(
         batch_size=batch_size,
         storage=LazyMemmapStorage(buffer_size),
         prefetch=n_optim,
+        transform=lambda td: td.to(device),
     )
     return replay_buffer
 
@@ -660,7 +661,7 @@ def get_loss_module(actor, gamma):
 # requires 3 hooks (``extend``, ``sample`` and ``update_priority``) which
 # can be cumbersome to implement.
 buffer_hook = ReplayBufferTrainer(
-    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
+    get_replay_buffer(buffer_size, n_optim, batch_size=batch_size, device=device),
     flatten_tensordicts=True,
 )
 buffer_hook.register(trainer)
@@ -750,6 +751,9 @@ def print_csv_files_in_folder(folder_path):
 
 print_csv_files_in_folder(logger.experiment.log_dir)
 
+trainer.shutdown()
+del trainer
+
 ###############################################################################
 # Conclusion and possible improvements
 # ------------------------------------
diff --git a/tutorials/sphinx-tutorials/pretrained_models.py b/tutorials/sphinx-tutorials/pretrained_models.py
index 67d65c7876d..4de341c7378 100644
--- a/tutorials/sphinx-tutorials/pretrained_models.py
+++ b/tutorials/sphinx-tutorials/pretrained_models.py
@@ -37,7 +37,7 @@
 import torch.cuda
 from tensordict.nn import TensorDictSequential
 from torch import nn
-from torchrl.envs import R3MTransform, TransformedEnv
+from torchrl.envs import Compose, R3MTransform, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import Actor
 
@@ -115,7 +115,7 @@
 from torchrl.data import LazyMemmapStorage, ReplayBuffer
 
 storage = LazyMemmapStorage(1000)
-rb = ReplayBuffer(storage=storage, transform=r3m)
+rb = ReplayBuffer(storage=storage, transform=Compose(lambda td: td.to(device), r3m))
 
 ##############################################################################
 # We can now collect the data (random rollouts for our purpose) and fill the replay

From 8a1e393cc7b62bb2f75051613ab4163cdcd869c7 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 10 Feb 2025 09:51:37 +0000
Subject: [PATCH 2/2] Update docs.html for BATCHED_PIPE_TIMEOUT

---
 .github/workflows/docs.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 17bf26fb86b..6c99cb9af05 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -108,6 +108,7 @@ jobs:
 
        # 11. Build doc
        export MAX_IDLE_COUNT=180  # Max 180 secs before killing an unresponsive collector
+       export BATCHED_PIPE_TIMEOUT=180
        cd ./docs
        # timeout 7m bash -ic "MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
        # bash -ic "PYOPENGL_PLATFORM=egl MUJOCO_GL=egl sphinx-build ./source _local_build" || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi
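
Note: the tutorial changes above all follow one pattern. The replay buffer keeps its
storage memory-mapped on CPU, and data is moved to the training device by a buffer
transform: a bare ``lambda td: td.to(device)`` in coding_dqn.py, or
``Compose(lambda td: td.to(device), r3m)`` in pretrained_models.py. A minimal,
self-contained sketch of that pattern follows; the device choice, buffer size and toy
tensors are illustrative assumptions, not part of the patch.

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

    device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed target device

    # Storage stays memory-mapped on CPU; the transform moves each sampled
    # batch to `device` only when sample() is called.
    rb = TensorDictReplayBuffer(
        batch_size=4,
        storage=LazyMemmapStorage(100),
        transform=lambda td: td.to(device),
    )

    data = TensorDict(
        {"observation": torch.randn(16, 3), "reward": torch.zeros(16, 1)},
        batch_size=[16],
    )
    rb.extend(data)       # stored on CPU
    sample = rb.sample()  # moved to `device` by the transform

Keeping the move inside the buffer transform means the same tutorial code runs
unchanged on CPU-only and GPU machines.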
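
Note: the trainers.py hunks make ``device`` optional in ReplayBufferTrainer and stop
forcing sampled batches onto a device when it is left as ``None``. A construction
sketch assuming the post-patch signature shown in the diff; the buffer contents and
sizes are hypothetical:

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
    from torchrl.trainers import ReplayBufferTrainer

    rb = TensorDictReplayBuffer(batch_size=4, storage=LazyMemmapStorage(100))
    rb.extend(TensorDict({"observation": torch.randn(8, 3)}, batch_size=[8]))

    # New default (device=None): sample() returns the batch exactly as stored.
    hook = ReplayBufferTrainer(rb, flatten_tensordicts=True)
    sample = hook.sample(TensorDict({}, batch_size=[]))  # input batch is ignored by sample()

    # Passing a device restores the old behaviour of moving every sampled batch:
    hook_on_device = ReplayBufferTrainer(rb, flatten_tensordicts=True, device="cpu")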