@@ -1,21 +1,22 @@
 # -*- coding: utf-8 -*-
 
 import os
-from datetime import datetime, timedelta
 import shutil
+import tempfile
+from datetime import datetime, timedelta
+from functools import reduce
 
 import dill
 import supar
 import torch
 import torch.distributed as dist
 from supar.utils import Config, Dataset
-import tempfile
 from supar.utils.field import Field
 from supar.utils.fn import download, get_rng_state, set_rng_state
 from supar.utils.logging import init_logger, logger, progress_bar
 from supar.utils.metric import Metric
 from supar.utils.parallel import DistributedDataParallel as DDP
-from supar.utils.parallel import is_master
+from supar.utils.parallel import gather, is_master
 from torch.optim import Adam
 from torch.optim.lr_scheduler import ExponentialLR
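The new `gather` import is the piece the distributed changes below hinge on: it must collect a picklable Python object from every rank and hand the full collection back. A minimal sketch of such a helper, assuming it is built on `torch.distributed.all_gather_object` (supar's actual implementation may differ):

```python
import torch.distributed as dist

def gather(obj):
    # Collect one picklable object per rank; every rank receives the full list.
    objs = [None] * dist.get_world_size()
    dist.all_gather_object(objs, obj)
    return objs
```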
@@ -118,12 +119,15 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, **kwargs):
         self.transform.train()
         logger.info("Loading the data")
         dataset = Dataset(self.transform, **args)
-        dataset.build(batch_size, buckets, False, False, workers)
+        dataset.build(batch_size, buckets, False, dist.is_initialized(), workers)
         logger.info(f"\n{dataset}")
 
         logger.info("Evaluating the dataset")
         start = datetime.now()
         loss, metric = self._evaluate(dataset.loader)
+        if dist.is_initialized():
+            loss, metric = reduce(lambda x, y: (x[0] + y[0], x[1] + y[1]), gather((loss, metric)))
+            loss = loss / dist.get_world_size()
         elapsed = datetime.now() - start
         logger.info(f"loss: {loss:.4f} - {metric}")
         logger.info(f"{elapsed}s elapsed, {len(dataset)/elapsed.total_seconds():.2f} Sents/s")
@@ -140,7 +144,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
 
         logger.info("Loading the data")
         dataset = Dataset(self.transform, **args)
-        dataset.build(batch_size, buckets, False, False, workers)
+        dataset.build(batch_size, buckets, False, dist.is_initialized(), workers)
         logger.info(f"\n{dataset}")
 
         logger.info("Making predictions on the dataset")
@@ -154,17 +158,25 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
                     f.write(str(s) + '\n')
         elapsed = datetime.now() - start
 
+        if dist.is_initialized():
+            dist.barrier()
+        if args.cache:
+            tdirs = gather(t) if dist.is_initialized() else (t,)
         if pred is not None and is_master():
             logger.info(f"Saving predicted results to {pred}")
             with open(pred, 'w') as f:
                 # merge all predictions into one single file
                 if args.cache:
-                    for s in progress_bar(sorted(os.listdir(t), key=lambda x: int(x))):
-                        with open(os.path.join(t, s)) as s:
+                    sentences = (os.path.join(i, s) for i in tdirs for s in os.listdir(i))
+                    for i in progress_bar(sorted(sentences, key=lambda x: int(os.path.basename(x)))):
+                        with open(i) as s:
                             shutil.copyfileobj(s, f)
                 else:
                     for s in progress_bar(dataset):
                         f.write(str(s) + '\n')
+        # exit until all files have been merged
+        if dist.is_initialized():
+            dist.barrier()
         logger.info(f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s")
 
         if not cache:
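The two barriers do distinct jobs: the first guarantees every rank has finished writing its per-sentence cache files before the master starts merging, and the second keeps the worker processes (and their temporary directories) alive until the merge completes. A condensed sketch of the whole pattern, assuming all ranks share a filesystem and that sentence indices are globally unique across ranks (all names here are hypothetical):

```python
import os
import shutil
import tempfile
import torch.distributed as dist

def write_and_merge(indexed_sentences, pred):
    # indexed_sentences yields (global_index, sentence) pairs for this rank.
    with tempfile.TemporaryDirectory() as t:
        for i, s in indexed_sentences:
            with open(os.path.join(t, str(i)), 'w') as f:
                f.write(str(s) + '\n')
        dist.barrier()  # every rank has flushed its files
        tdirs = [None] * dist.get_world_size()
        dist.all_gather_object(tdirs, t)
        if dist.get_rank() == 0:  # the master merges everything in index order
            paths = (os.path.join(d, name) for d in tdirs for name in os.listdir(d))
            with open(pred, 'w') as f:
                for p in sorted(paths, key=lambda p: int(os.path.basename(p))):
                    with open(p) as part:
                        shutil.copyfileobj(part, f)
        dist.barrier()  # temp dirs must outlive the master's merge
```

Without the trailing barrier, a fast worker could leave the `with tempfile.TemporaryDirectory()` block and delete its directory while the master is still reading from it.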