
Commit e89764e

Authored by borisdayma, with awaelchli and Borda
feat(wandb): offset logging step when resuming (Lightning-AI#5050)
* feat(wandb): offset logging step when resuming
* feat(wandb): output warnings
* fix(wandb): allow step to be None
* test(wandb): update tests
* feat(wandb): display warning only once
* style: fix PEP issues
* tests(wandb): fix tests
* tests(wandb): improve test
* style: fix whitespace
* feat: improve warning
* feat(wandb): use variable from class instance
* tests(wandb): check warnings
* feat(wandb): use WarningCache
* tests(wandb): fix tests
* style: fix formatting

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent: 4c34855

File tree (3 files changed: +42 −3)

    pytorch_lightning/loggers/wandb.py   +10 −1
    tests/loggers/test_all.py             +5 −1
    tests/loggers/test_wandb.py          +27 −1

pytorch_lightning/loggers/wandb.py (+10 −1)
@@ -31,6 +31,7 @@
 
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.warning_utils import WarningCache
 
 
 class WandbLogger(LightningLoggerBase):
@@ -66,6 +67,9 @@ class WandbLogger(LightningLoggerBase):
         wandb_logger = WandbLogger()
         trainer = Trainer(logger=wandb_logger)
 
+    Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
+    make sure to use `commit=False` so the logging step does not increase.
+
     See Also:
         - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
           Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
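
To make the new docstring Note concrete, here is a minimal sketch of mixing manual logging with the Trainer; the metric name and value are illustrative, not part of this commit:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger()
trainer = Trainer(logger=wandb_logger)

# commit=False attaches the data to the current wandb step instead of
# advancing wandb's internal step counter, so manual logging cannot push
# the run's step ahead of the steps the Trainer will log at.
wandb_logger.experiment.log({"custom_metric": 0.5}, commit=False)
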
@@ -103,8 +107,9 @@ def __init__(
         self._log_model = log_model
         self._prefix = prefix
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, etc)
+        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
         self._step_offset = 0
+        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
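
WarningCache is Lightning's small deduplication helper; keeping one on the instance lets log_metrics warn once instead of flooding the output on every step. A simplified stand-in for the behaviour it provides (the real class lives in pytorch_lightning.utilities.warning_utils, as imported above):

import warnings

class WarningCacheSketch:
    """Simplified stand-in: emits each distinct message at most once."""

    def __init__(self):
        self._seen = set()

    def warn(self, message):
        # Only warn the first time a given message is seen.
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message)
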
@@ -134,6 +139,8 @@ def experiment(self) -> Run:
             self._experiment = wandb.init(
                 name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
                 id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
+        # offset logging step when resuming a run
+        self._step_offset = self._experiment.step
         # save checkpoints in wandb dir to upload on W&B servers
         if self._log_model:
             self._save_dir = self._experiment.dir
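
This is the heart of the fix: with resume='allow' and a known id, the run returned by wandb.init already carries the step count of the previous session, and that count becomes the offset. A hedged usage sketch; the run id and the step count are placeholders:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Reuse the id of an earlier run ('abc123' is a made-up placeholder).
wandb_logger = WandbLogger(id='abc123')
trainer = Trainer(logger=wandb_logger)

# Accessing the experiment resumes the run. If the previous session
# stopped at wandb step 500, the offset is 500, so new Trainer steps
# 0, 1, 2, ... are logged as 500, 501, 502, ...
print(wandb_logger.experiment.step)
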
@@ -154,6 +161,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
+        if step is not None and step + self._step_offset < self.experiment.step:
+            self.warning_cache.warn('Trying to log at a previous step. Use `commit=False` when logging metrics manually.')
         self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
 
     @property
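
The warning condition mirrors the scenario exercised in test_wandb.py below: after logging at Trainer step 3 with an offset of 3, wandb sits at step 6, so a later request for step 2 maps to 5 and trips the check. A small worked trace of that arithmetic, with values taken from the test:

step_offset = 3         # offset picked up from the resumed run
current_wandb_step = 6  # wandb's step after logging {'acc': 1.0} at step=3

requested_step = 2
effective_step = requested_step + step_offset   # 2 + 3 == 5
assert effective_step < current_wandb_step      # 5 < 6 -> warning is emitted
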

tests/loggers/test_all.py (+5 −1)
@@ -74,7 +74,9 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
     with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
         _test_loggers_fit_test(tmpdir, TestTubeLogger)
 
-    with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+    with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
+        wandb.run = None
+        wandb.init().step = 0
         _test_loggers_fit_test(tmpdir, WandbLogger)
 
@@ -368,5 +370,7 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     # WandB
     with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
         logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix)
+        wandb.run = None
+        wandb.init().step = 0
         logger.log_metrics({"test": 1.0}, step=0)
         logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
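
The extra mock setup in both tests exists because WandbLogger now reads wandb.run and experiment.step. With a bare MagicMock those attributes are themselves MagicMocks, so the new step + self._step_offset arithmetic in log_metrics would not compare meaningfully against experiment.step. A standalone illustration of why the values are pinned (not part of the test suite):

from unittest import mock

wandb = mock.MagicMock()
# Unconfigured attributes are MagicMocks, not numbers:
print(type(wandb.init().step))   # <class 'unittest.mock.MagicMock'>

# Pinning concrete values restores integer semantics for the logger:
wandb.run = None                 # force the wandb.init() code path
wandb.init().step = 0            # fresh run, so the step offset is 0
print(2 + wandb.init().step)     # 2
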

tests/loggers/test_wandb.py (+27 −1)
@@ -22,8 +22,14 @@
 from tests.base import EvalModelTemplate, BoringModel
 
 
+def get_warnings(recwarn):
+    warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
+    recwarn.clear()
+    return warnings_text
+
+
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
-def test_wandb_logger_init(wandb):
+def test_wandb_logger_init(wandb, recwarn):
     """Verify that basic functionality of wandb logger works.
     Wandb doesn't work well with pytest so we have to mock it out here."""

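get_warnings builds on pytest's built-in recwarn fixture, which records every warning raised during a test; clearing the recorder between assertions is what lets the test tell "warned just now" apart from "warned earlier". A tiny self-contained sketch of that pattern (the warning text is illustrative):

import warnings

def test_warns(recwarn):
    # recwarn is injected by pytest and records all warnings in the test.
    warnings.warn('something happened')
    messages = '\n'.join(str(w.message) for w in recwarn.list)
    assert 'something happened' in messages

    recwarn.clear()
    # After clearing, previously recorded warnings are gone.
    assert len(recwarn.list) == 0
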
@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
+    # mock wandb step
+    wandb.init().step = 0
+
     # test wandb.init not called if there is a W&B run
     wandb.init().log.reset_mock()
     wandb.init.reset_mock()
@@ -49,15 +58,28 @@
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
 
+    # log hyper parameters
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
     wandb.init().config.update.assert_called_once_with(
         {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
         allow_val_change=True,
     )
 
+    # watch a model
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
 
+    # verify warning for logging at a previous step
+    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+    # current step from wandb should be 6 (last logged step)
+    logger.experiment.step = 6
+    # logging at step 2 should raise a warning (step_offset is still 3)
+    logger.log_metrics({'acc': 1.0}, step=2)
+    assert 'Trying to log at a previous step' in get_warnings(recwarn)
+    # logging again at step 2 should not display again the same warning
+    logger.log_metrics({'acc': 1.0}, step=2)
+    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+
     assert logger.name == wandb.init().project_name()
     assert logger.version == wandb.init().id

@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
7193
class Experiment:
7294
""" """
7395
id = 'the_id'
96+
step = 0
7497

7598
def project_name(self):
7699
return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     assert logger.name is None
 
     # mock return values of experiment
+    wandb.run = None
+    wandb.init().step = 0
     logger.experiment.id = '1'
     logger.experiment.project_name.return_value = 'project'
+    logger.experiment.step = 0
 
     for _ in range(2):
         _ = logger.experiment
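
One subtlety worth calling out: this test pins logger.experiment.step in addition to wandb.init().step because, per the wandb.py diff above, the experiment property re-reads self._experiment.step on every access to refresh the offset. A minimal sketch of that refresh, using a mock in place of a real run:

from unittest import mock

run = mock.MagicMock()
run.step = 0                   # must be an integer: it becomes the offset

# What the experiment property effectively does on each access:
step_offset = run.step
assert step_offset == 0
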
