Skip to content

Commit fc564a5

Browse files
committed
Add device arg for Parser.load
1 parent d7bd7df commit fc564a5

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

supar/parsers/parser.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
9292
if self.patience < 1:
9393
break
9494
dist.barrier()
95+
args.device = args.local_rank
9596
parser = self.load(**args)
9697
loss, metric = parser._evaluate(test.loader)
9798
# only allow the master device to save models
@@ -163,7 +164,7 @@ def build(cls, path, **kwargs):
163164
raise NotImplementedError
164165

165166
@classmethod
166-
def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
167+
def load(cls, path, reload=False, src='github', checkpoint=False, device=None, **kwargs):
167168
r"""
168169
Loads a parser with data fields and pretrained model parameters.
169170
@@ -181,6 +182,9 @@ def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
181182
Default: ``'github'``.
182183
checkpoint (bool):
183184
If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
185+
device (:class:`torch.device`):
186+
The desired device of the model parameters.
187+
If ``None``, uses the default GPU device (if available), otherwise uses the CPU device. Default: ``None``.
184188
kwargs (dict):
185189
A dict holding unconsumed arguments for updating training configs and initializing the model.
186190
@@ -191,17 +195,20 @@ def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
191195
"""
192196

193197
args = Config(**locals())
194-
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
195-
state = torch.load(path if os.path.exists(path) else download(supar.MODEL[src].get(path, path), reload=reload))
198+
if device is None:
199+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
200+
if not os.path.exists(path):
201+
path = download(supar.MODEL[src].get(path, path), reload=reload)
202+
state = torch.load(path)
196203
cls = supar.PARSER[state['name']] if cls.NAME is None else cls
197204
args = state['args'].update(args)
198205
model = cls.MODEL(**args)
199206
model.load_pretrained(state['pretrained'])
200207
model.load_state_dict(state['state_dict'], False)
201-
model.to(args.device)
208+
model.to(device)
202209
transform = state['transform']
203210
parser = cls(args, model, transform)
204-
parser.checkpoint_state_dict = state['checkpoint_state_dict'] if args.checkpoint else None
211+
parser.checkpoint_state_dict = state['checkpoint_state_dict'] if checkpoint else None
205212
return parser
206213

207214
def save(self, path):

0 commit comments

Comments
 (0)