Skip to content

Commit 76a8fb1

Browse files
committed
Fix if condition error for dist
1 parent 7f1a093 commit 76a8fb1

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

supar/parsers/parser.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
7171
start = datetime.now()
7272

7373
logger.info(f"Epoch {epoch} / {args.epochs}:")
74-
with self.model.join():
74+
if dist.is_initialized():
75+
with self.model.join():
76+
self._train(train.loader)
77+
else:
7578
self._train(train.loader)
7679
loss, dev_metric = self._evaluate(dev.loader)
7780
logger.info(f"{'dev:':5} loss: {loss:.4f} - {dev_metric}")
@@ -92,7 +95,8 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
9295
logger.info(f"{t}s elapsed\n")
9396
if self.patience < 1:
9497
break
95-
dist.barrier()
98+
if dist.is_initialized():
99+
dist.barrier()
96100
args.device = args.local_rank
97101
parser = self.load(**args)
98102
loss, metric = parser._evaluate(test.loader)
@@ -165,7 +169,7 @@ def build(cls, path, **kwargs):
165169
raise NotImplementedError
166170

167171
@classmethod
168-
def load(cls, path, reload=False, src='github', checkpoint=False, device=None, **kwargs):
172+
def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
169173
r"""
170174
Loads a parser with data fields and pretrained model parameters.
171175
@@ -183,9 +187,6 @@ def load(cls, path, reload=False, src='github', checkpoint=False, device=None, *
183187
Default: ``'github'``.
184188
checkpoint (bool):
185189
If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
186-
device (:class:`torch.device`):
187-
The desired device of the model parameters.
188-
If ``None``, uses the default GPU device (if available), otherwise uses the CPU device. Default: ``None``.
189190
kwargs (dict):
190191
A dict holding unconsumed arguments for updating training configs and initializing the model.
191192
@@ -196,11 +197,10 @@ def load(cls, path, reload=False, src='github', checkpoint=False, device=None, *
196197
"""
197198

198199
args = Config(**locals())
199-
if args.device is None:
200-
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
200+
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
201201
if not os.path.exists(path):
202202
path = download(supar.MODEL[src].get(path, path), reload=reload)
203-
state = torch.load(path)
203+
state = torch.load(path, map_location='cpu')
204204
cls = supar.PARSER[state['name']] if cls.NAME is None else cls
205205
args = state['args'].update(args)
206206
model = cls.MODEL(**args)

0 commit comments

Comments
 (0)