22
22
from tests .base import EvalModelTemplate , BoringModel
23
23
24
24
25
+ def get_warnings (recwarn ):
26
+ warnings_text = '\n ' .join (str (w .message ) for w in recwarn .list )
27
+ recwarn .clear ()
28
+ return warnings_text
29
+
30
+
25
31
@mock .patch ('pytorch_lightning.loggers.wandb.wandb' )
26
- def test_wandb_logger_init (wandb ):
32
+ def test_wandb_logger_init (wandb , recwarn ):
27
33
"""Verify that basic functionality of wandb logger works.
28
34
Wandb doesn't work well with pytest so we have to mock it out here."""
29
35
@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
34
40
wandb .init .assert_called_once ()
35
41
wandb .init ().log .assert_called_once_with ({'acc' : 1.0 }, step = None )
36
42
43
+ # mock wandb step
44
+ wandb .init ().step = 0
45
+
37
46
# test wandb.init not called if there is a W&B run
38
47
wandb .init ().log .reset_mock ()
39
48
wandb .init .reset_mock ()
@@ -49,15 +58,28 @@ def test_wandb_logger_init(wandb):
49
58
logger .log_metrics ({'acc' : 1.0 }, step = 3 )
50
59
wandb .init ().log .assert_called_with ({'acc' : 1.0 }, step = 6 )
51
60
61
+ # log hyper parameters
52
62
logger .log_hyperparams ({'test' : None , 'nested' : {'a' : 1 }, 'b' : [2 , 3 , 4 ]})
53
63
wandb .init ().config .update .assert_called_once_with (
54
64
{'test' : 'None' , 'nested/a' : 1 , 'b' : [2 , 3 , 4 ]},
55
65
allow_val_change = True ,
56
66
)
57
67
68
+ # watch a model
58
69
logger .watch ('model' , 'log' , 10 )
59
70
wandb .init ().watch .assert_called_once_with ('model' , log = 'log' , log_freq = 10 )
60
71
72
+ # verify warning for logging at a previous step
73
+ assert 'Trying to log at a previous step' not in get_warnings (recwarn )
74
+ # current step from wandb should be 6 (last logged step)
75
+ logger .experiment .step = 6
76
+ # logging at step 2 should raise a warning (step_offset is still 3)
77
+ logger .log_metrics ({'acc' : 1.0 }, step = 2 )
78
+ assert 'Trying to log at a previous step' in get_warnings (recwarn )
79
+ # logging again at step 2 should not display again the same warning
80
+ logger .log_metrics ({'acc' : 1.0 }, step = 2 )
81
+ assert 'Trying to log at a previous step' not in get_warnings (recwarn )
82
+
61
83
assert logger .name == wandb .init ().project_name ()
62
84
assert logger .version == wandb .init ().id
63
85
@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
71
93
class Experiment :
72
94
""" """
73
95
id = 'the_id'
96
+ step = 0
74
97
75
98
def project_name (self ):
76
99
return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
108
131
assert logger .name is None
109
132
110
133
# mock return values of experiment
134
+ wandb .run = None
135
+ wandb .init ().step = 0
111
136
logger .experiment .id = '1'
112
137
logger .experiment .project_name .return_value = 'project'
138
+ logger .experiment .step = 0
113
139
114
140
for _ in range (2 ):
115
141
_ = logger .experiment
0 commit comments