Skip to content

Commit f7f1eff

Browse files
committed
Multi-GPU eval/prediction supported
1 parent ba18616 commit f7f1eff

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

supar/parsers/parser.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
# -*- coding: utf-8 -*-
22

33
import os
4-
from datetime import datetime, timedelta
54
import shutil
5+
import tempfile
6+
from datetime import datetime, timedelta
7+
from functools import reduce
68

79
import dill
810
import supar
911
import torch
1012
import torch.distributed as dist
1113
from supar.utils import Config, Dataset
12-
import tempfile
1314
from supar.utils.field import Field
1415
from supar.utils.fn import download, get_rng_state, set_rng_state
1516
from supar.utils.logging import init_logger, logger, progress_bar
1617
from supar.utils.metric import Metric
1718
from supar.utils.parallel import DistributedDataParallel as DDP
18-
from supar.utils.parallel import is_master
19+
from supar.utils.parallel import gather, is_master
1920
from torch.optim import Adam
2021
from torch.optim.lr_scheduler import ExponentialLR
2122

@@ -118,12 +119,15 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, **kwargs):
118119
self.transform.train()
119120
logger.info("Loading the data")
120121
dataset = Dataset(self.transform, **args)
121-
dataset.build(batch_size, buckets, False, False, workers)
122+
dataset.build(batch_size, buckets, False, dist.is_initialized(), workers)
122123
logger.info(f"\n{dataset}")
123124

124125
logger.info("Evaluating the dataset")
125126
start = datetime.now()
126127
loss, metric = self._evaluate(dataset.loader)
128+
if dist.is_initialized():
129+
loss, metric = reduce(lambda x, y: (x[0] + y[0], x[1] + y[1]), gather((loss, metric)))
130+
loss = loss / dist.get_world_size()
127131
elapsed = datetime.now() - start
128132
logger.info(f"loss: {loss:.4f} - {metric}")
129133
logger.info(f"{elapsed}s elapsed, {len(dataset)/elapsed.total_seconds():.2f} Sents/s")
@@ -140,7 +144,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
140144

141145
logger.info("Loading the data")
142146
dataset = Dataset(self.transform, **args)
143-
dataset.build(batch_size, buckets, False, False, workers)
147+
dataset.build(batch_size, buckets, False, dist.is_initialized(), workers)
144148
logger.info(f"\n{dataset}")
145149

146150
logger.info("Making predictions on the dataset")
@@ -154,17 +158,25 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
154158
f.write(str(s) + '\n')
155159
elapsed = datetime.now() - start
156160

161+
if dist.is_initialized():
162+
dist.barrier()
163+
if args.cache:
164+
tdirs = gather(t) if dist.is_initialized() else (t,)
157165
if pred is not None and is_master():
158166
logger.info(f"Saving predicted results to {pred}")
159167
with open(pred, 'w') as f:
160168
# merge all predictions into one single file
161169
if args.cache:
162-
for s in progress_bar(sorted(os.listdir(t), key=lambda x: int(x))):
163-
with open(os.path.join(t, s)) as s:
170+
sentences = (os.path.join(i, s) for i in tdirs for s in os.listdir(i))
171+
for i in progress_bar(sorted(sentences, key=lambda x: int(os.path.basename(x)))):
172+
with open(i) as s:
164173
shutil.copyfileobj(s, f)
165174
else:
166175
for s in progress_bar(dataset):
167176
f.write(str(s) + '\n')
177+
# exit until all files have been merged
178+
if dist.is_initialized():
179+
dist.barrier()
168180
logger.info(f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s")
169181

170182
if not cache:

supar/utils/parallel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# -*- coding: utf-8 -*-
22

3+
from typing import Any, Iterable
4+
35
import torch.distributed as dist
46
import torch.nn as nn
57

@@ -27,3 +29,9 @@ def get_free_port():
2729
port = str(s.getsockname()[1])
2830
s.close()
2931
return port
32+
33+
34+
def gather(obj: Any) -> Iterable[Any]:
35+
objs = [None] * dist.get_world_size()
36+
dist.all_gather_object(objs, obj)
37+
return objs

0 commit comments

Comments
 (0)