
Commit eab0afc

[Trainer] implement gradient_accumulation_steps support in DeepSpeed integration (huggingface#10310)
* implement gradient_accumulation_steps support in DeepSpeed integration
* typo
* cleanup
* cleanup
1 parent f991dae commit eab0afc

5 files changed: +159 −24 lines changed

docs/source/main_classes/trainer.rst

Lines changed: 22 additions & 0 deletions
```diff
@@ -830,6 +830,28 @@ Here is an example of the ``amp`` configuration:
     }
 
+Gradient Accumulation
+=======================================================================================================================
+
+While DeepSpeed normally gets gradient accumulation configured with:
+
+.. code-block:: json
+
+    {
+        "gradient_accumulation_steps": 3
+    }
+
+in this case, to enable gradient accumulation, pass the ``--gradient_accumulation_steps`` command line argument as
+normal and it will get injected into the DeepSpeed configuration.
+
+If you try to add it directly to the configuration file, you will receive an error from the Trainer - this is because
+this setting is needed by the Trainer too, and so this approach ensures that there is a single way of setting this
+value and thus avoids potential subtle errors.
+
+
 
 Gradient Clipping
 =======================================================================================================================
```
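For readers following along, here is a hedged sketch of the same setup done programmatically rather than via the command line. The output directory and batch sizes are made-up illustration values, and `ds_config.json` stands in for your own DeepSpeed config file (which, per the docs above, must not itself contain `gradient_accumulation_steps`):

```python
from transformers import TrainingArguments

# Illustrative values only; "ds_config.json" is a placeholder for your own
# DeepSpeed config file. Do NOT put "gradient_accumulation_steps" inside that
# file - set it here (or via --gradient_accumulation_steps) and the Trainer
# injects it into the DeepSpeed configuration for you.
args = TrainingArguments(
    output_dir="output",
    deepspeed="ds_config.json",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch = 4 * 2 = 8 per device
)
```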

examples/tests/deepspeed/ds_config.json

Lines changed: 1 addition & 0 deletions
```diff
@@ -3,6 +3,7 @@
         "enabled": true,
         "loss_scale": 0,
         "loss_scale_window": 1000,
+        "initial_scale_power": 32,
         "hysteresis": 2,
         "min_loss_scale": 1
     },
```

examples/tests/deepspeed/test_deepspeed.py

Lines changed: 84 additions & 15 deletions
```diff
@@ -23,14 +23,19 @@
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
-    mockenv,
+    mockenv_context,
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
 )
 from transformers.trainer_utils import set_seed
 
 
+bindir = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(f"{bindir}/../../../tests")
+from test_trainer import get_regression_trainer  # noqa
+
+
 set_seed(42)
 MBART_TINY = "sshleifer/tiny-mbart"
 
@@ -51,32 +56,96 @@ def require_deepspeed(test_case):
     return test_case
 
 
-@slow
 @require_deepspeed
 @require_torch_gpu
-class TestDeepSpeed(TestCasePlus):
+class TrainerIntegrationDeepSpeed(TestCasePlus):
+    """ This class is for testing directly via get_regression_trainer """
+
+    def setUp(self):
+        super().setUp()
+        self.dist_env_1_gpu = dict(
+            MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
+        )
+        self.ds_config_file = f"{self.test_file_dir_str}/ds_config.json"
 
-    # this setup emulates a notebook where a launcher needs to be emulated by hand
-    @mockenv(MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1")
     def test_fake_notebook_no_launcher(self):
-        sys.path.append(self.tests_dir_str)
-        from test_trainer import get_regression_trainer
 
-        del sys.path[-1]  # restore
-        ds_config_file = f"{self.test_file_dir_str}/ds_config.json"
+        # this setup emulates a notebook where a launcher needs to be emulated by hand
+
         with CaptureStd() as cs:
-            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_file)
-            trainer.train()
+            with mockenv_context(**self.dist_env_1_gpu):
+                trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file)
+                trainer.train()
         assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
 
+    def test_gradient_accumulation(self):
+
+        # this test measures that we get identical weights and similar loss with:
+        # 1. per_device_train_batch_size=8, gradient_accumulation_steps=1
+        # 2. per_device_train_batch_size=4, gradient_accumulation_steps=2
+        # since the 2nd should produce the effective batch of the 1st, with the same results
+        #
+        # I can get an identical loss for a small train_len=32, plus the power of the initial
+        # dynamic loss scale value set to:
+        #   "fp16.initial_scale_power": 1
+        # plus having the same WarmupLR's warmup_min_lr == warmup_max_lr in the config file
+        # but for some reason going to train_len=64 the weights start to mismatch with this setup.
+        # the culprit seems to be `initial_scale_power` - putting it back to its default 32 keeps the weights identical
+
+        train_len = 64
+        a = b = 0.0
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            no_grad_accum_trainer = get_regression_trainer(
+                a=a,
+                b=b,
+                local_rank=0,
+                train_len=train_len,
+                deepspeed=self.ds_config_file,
+                per_device_train_batch_size=8,
+                gradient_accumulation_steps=1,
+            )
+            no_grad_accum_result = no_grad_accum_trainer.train()
+            no_grad_accum_loss = no_grad_accum_result.training_loss
+            no_grad_accum_a = no_grad_accum_trainer.model.a.item()
+            no_grad_accum_b = no_grad_accum_trainer.model.b.item()
+            # make sure the optimizer kicked in - if it hasn't changed from the original value of a then make train_len bigger
+            self.assertNotEqual(no_grad_accum_a, a)
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            yes_grad_accum_trainer = get_regression_trainer(
+                a=a,
+                b=b,
+                local_rank=0,
+                train_len=train_len,
+                deepspeed=self.ds_config_file,
+                per_device_train_batch_size=4,
+                gradient_accumulation_steps=2,
+            )
+            yes_grad_accum_result = yes_grad_accum_trainer.train()
+            yes_grad_accum_loss = yes_grad_accum_result.training_loss
+            yes_grad_accum_a = yes_grad_accum_trainer.model.a.item()
+            yes_grad_accum_b = yes_grad_accum_trainer.model.b.item()
+            self.assertNotEqual(yes_grad_accum_a, a)
+
+        # training with half the batch size but accumulation steps of 2 should give the same weights
+        self.assertEqual(no_grad_accum_a, yes_grad_accum_a)
+        self.assertEqual(no_grad_accum_b, yes_grad_accum_b)
+
+        # see the note above on how to get identical loss on a small bs
+        self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
+
+
+@slow
+@require_deepspeed
+@require_torch_gpu
+class TestDeepSpeed(TestCasePlus):
+    """ This class is for testing via an external script """
+
     @require_torch_multi_gpu
     def test_basic_distributed(self):
         self.run_quick(distributed=True)
 
-    @require_torch_multi_gpu
-    def test_grad_acum(self):
-        self.run_quick(distributed=True, extra_args_str="--gradient_accumulation_steps 2")
-
     def test_do_eval_no_train(self):
         # we should not fail if train is skipped
         output_dir = self.run_trainer(
```
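The invariant the new test relies on can be seen outside the Trainer too. Below is a hypothetical, self-contained sketch (not part of the commit) showing that, in full precision, one SGD step on a batch of 8 matches two accumulated half-batches of 4, provided each micro-batch loss is divided by the number of accumulation steps:

```python
import torch

def sgd_update(xs, ys, accum_steps, lr=0.1):
    # `a` plays the role of model.a in the regression model used by get_regression_trainer
    a = torch.zeros(1, requires_grad=True)
    micro = len(xs) // accum_steps
    for i in range(accum_steps):
        x, y = xs[i * micro:(i + 1) * micro], ys[i * micro:(i + 1) * micro]
        loss = ((a * x - y) ** 2).mean() / accum_steps  # same scaling the Trainer applies
        loss.backward()  # gradients accumulate in a.grad across micro-batches
    with torch.no_grad():
        a -= lr * a.grad  # single optimizer step at the accumulation boundary
    return a.item()

xs, ys = torch.arange(8.0), 2 * torch.arange(8.0)
# one batch of 8 gives the same update as two accumulated micro-batches of 4
assert abs(sgd_update(xs, ys, accum_steps=1) - sgd_update(xs, ys, accum_steps=2)) < 1e-6
```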

src/transformers/testing_utils.py

Lines changed: 39 additions & 3 deletions
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import inspect
 import logging
 import os
@@ -830,14 +831,49 @@ def tearDown(self):
 
 def mockenv(**kwargs):
     """
-    this is a convenience wrapper, that allows this:
+    this is a convenience wrapper, that allows this ::
+
+    @mockenv(RUN_SLOW=True, USE_TF=False)
+    def test_something():
+        run_slow = os.getenv("RUN_SLOW", False)
+        use_tf = os.getenv("USE_TF", False)
 
-    @mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): run_slow = os.getenv("RUN_SLOW", False) use_tf =
-    os.getenv("USE_TF", False)
     """
     return unittest.mock.patch.dict(os.environ, kwargs)
 
 
+# from https://stackoverflow.com/a/34333710/9201239
+@contextlib.contextmanager
+def mockenv_context(*remove, **update):
+    """
+    Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv
+
+    The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
+
+    Args:
+        remove: Environment variables to remove.
+        update: Dictionary of environment variables and values to add/update.
+    """
+    env = os.environ
+    update = update or {}
+    remove = remove or []
+
+    # List of environment variables being updated or removed.
+    stomped = (set(update.keys()) | set(remove)) & set(env.keys())
+    # Environment variables and values to restore on exit.
+    update_after = {k: env[k] for k in stomped}
+    # Environment variables and values to remove on exit.
+    remove_after = frozenset(k for k in update if k not in env)
+
+    try:
+        env.update(update)
+        [env.pop(k, None) for k in remove]
+        yield
+    finally:
+        env.update(update_after)
+        [env.pop(k) for k in remove_after]
+
+
 # --- pytest conf functions --- #
 
 # to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
```
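A hedged usage sketch of the new helper (assuming `WORLD_SIZE` is not already set in the surrounding environment), showing why the tests above can wrap each trainer run in it:

```python
import os
from transformers.testing_utils import mockenv_context

# Assumes WORLD_SIZE is not already set in the surrounding environment.
assert "WORLD_SIZE" not in os.environ
with mockenv_context(MASTER_ADDR="localhost", MASTER_PORT="10999", WORLD_SIZE="1"):
    # os.environ itself is patched in-place, so any code reading it sees the values
    assert os.environ["WORLD_SIZE"] == "1"
# on exit, keys added by the context are removed and stomped values are restored
assert "WORLD_SIZE" not in os.environ
```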

src/transformers/trainer.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -718,7 +718,7 @@ def call_model_init(self, trial=None):
     def _wrap_model(self, model, training=True):
         # already initialized its own DDP and AMP
         if self.deepspeed:
-            return model
+            return self.deepspeed
 
         # Mixed precision training with apex (torch < 1.6)
         if self.use_apex and training:
@@ -996,6 +996,10 @@ def train(
                     tr_loss += self.training_step(model, inputs)
                 self._total_flos += float(self.floating_point_ops(inputs))
 
+                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
+                if self.deepspeed:
+                    self.deepspeed.step()
+
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= self.args.gradient_accumulation_steps
@@ -1021,7 +1025,7 @@ def train(
 
                     # Optimizer step
                     if self.deepspeed:
-                        self.deepspeed.step()
+                        pass  # called outside the loop
                     elif is_torch_tpu_available():
                         xm.optimizer_step(self.optimizer)
                     elif self.use_amp:
@@ -1030,7 +1034,9 @@ def train(
                     else:
                         self.optimizer.step()
 
-                    self.lr_scheduler.step()
+                    if not self.deepspeed:
+                        self.lr_scheduler.step()
+
                     model.zero_grad()
                     self.state.global_step += 1
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
@@ -1388,7 +1394,6 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         Return:
             :obj:`torch.Tensor`: The tensor with training loss on this batch.
         """
-
         model.train()
         inputs = self._prepare_inputs(inputs)
 
@@ -1401,7 +1406,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
 
-        if self.args.gradient_accumulation_steps > 1:
+        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
+            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps
 
         if self.use_amp:
@@ -1410,7 +1416,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         elif self.deepspeed:
-            self.deepspeed.backward(loss)
+            # loss gets scaled under gradient_accumulation_steps in deepspeed
+            loss = self.deepspeed.backward(loss)
         else:
             loss.backward()
```
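The reshuffling above follows DeepSpeed's calling convention: the engine counts micro-steps itself, so `engine.step()` must run on every batch and only applies the optimizer (and the LR scheduler it owns) at accumulation boundaries, while `engine.backward()` takes care of dividing the loss by `gradient_accumulation_steps`. A minimal sketch of that convention, assuming `model`, `params`, `ds_config`, and `dataloader` are placeholders you supply and that the wrapped model maps a batch to a loss:

```python
import deepspeed  # assumes the deepspeed package is installed

# Hypothetical minimal setup; model, params, ds_config, dataloader are placeholders.
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, model_parameters=params, config_params=ds_config
)

for batch in dataloader:
    loss = engine(batch)    # forward through the engine, which wraps the model
    engine.backward(loss)   # scales loss by gradient_accumulation_steps internally
    engine.step()           # call on every micro-step; applies the optimizer and
                            # lr-scheduler update only at accumulation boundaries
```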
