Skip to content

Commit 9b208f6

Browse files
committed
Add hlt src
1 parent de1e1d5 commit 9b208f6

File tree

5 files changed

+66
-21
lines changed

5 files changed

+66
-21
lines changed

supar/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
BiaffineSemanticDependencyParser,
2828
VISemanticDependencyParser]}
2929

30-
SRC = 'https://github.com/yzhangcs/parser/releases/download'
30+
SRC = {'github': 'https://github.com/yzhangcs/parser/releases/download',
31+
'hlt': 'http://hlt.suda.edu.cn/LA/yzhang/supar'}
3132
NAME = {
3233
'biaffine-dep-en': 'ptb.biaffine.dep.lstm.char',
3334
'biaffine-dep-zh': 'ctb7.biaffine.dep.lstm.char',
@@ -48,5 +49,5 @@
4849
'biaffine-sdp-roberta-en': 'dm.biaffine.sdp.roberta',
4950
'biaffine-sdp-electra-zh': 'semeval16.biaffine.sdp.electra'
5051
}
51-
MODEL = {n: f'{SRC}/v{__version__}/{m}.zip' for n, m in NAME.items()}
52-
CONFIG = {n: f'{SRC}/v{__version__}/{m}.ini' for n, m in NAME.items()}
52+
MODEL = {n: f"{SRC['github']}/v{__version__}/{m}.zip" for n, m in NAME.items()}
53+
CONFIG = {n: f"{SRC['github']}/v{__version__}/{m}.ini" for n, m in NAME.items()}

supar/parsers/con.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
129129
return super().predict(**Config().update(locals()))
130130

131131
@classmethod
132-
def load(cls, path, reload=False, **kwargs):
132+
def load(cls, path, reload=False, src=None, **kwargs):
133133
r"""
134134
Loads a parser with data fields and pretrained model parameters.
135135
@@ -140,6 +140,11 @@ def load(cls, path, reload=False, **kwargs):
140140
- a local path to a pretrained model, e.g., ``./<path>/model``.
141141
reload (bool):
142142
Whether to discard the existing cache and force a fresh download. Default: ``False``.
143+
src (str):
144+
Specifies where to download the model.
145+
``'github'``: github release page.
146+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
147+
Default: None.
143148
kwargs (dict):
144149
A dict holding unconsumed arguments for updating training configs and initializing the model.
145150
@@ -149,7 +154,7 @@ def load(cls, path, reload=False, **kwargs):
149154
>>> parser = Parser.load('./ptb.crf.con.lstm.char')
150155
"""
151156

152-
return super().load(path, reload, **kwargs)
157+
return super().load(path, reload, src, **kwargs)
153158

154159
def _train(self, loader):
155160
self.model.train()
@@ -411,7 +416,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
411416
return super().predict(**Config().update(locals()))
412417

413418
@classmethod
414-
def load(cls, path, reload=False, **kwargs):
419+
def load(cls, path, reload=False, src=None, **kwargs):
415420
r"""
416421
Loads a parser with data fields and pretrained model parameters.
417422
@@ -422,6 +427,11 @@ def load(cls, path, reload=False, **kwargs):
422427
- a local path to a pretrained model, e.g., ``./<path>/model``.
423428
reload (bool):
424429
Whether to discard the existing cache and force a fresh download. Default: ``False``.
430+
src (str):
431+
Specifies where to download the model.
432+
``'github'``: github release page.
433+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
434+
Default: None.
425435
kwargs (dict):
426436
A dict holding unconsumed arguments for updating training configs and initializing the model.
427437
@@ -431,7 +441,7 @@ def load(cls, path, reload=False, **kwargs):
431441
>>> parser = Parser.load('./ptb.vi.con.lstm.char')
432442
"""
433443

434-
return super().load(path, reload, **kwargs)
444+
return super().load(path, reload, src, **kwargs)
435445

436446
def _train(self, loader):
437447
self.model.train()

supar/parsers/dep.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
124124
return super().predict(**Config().update(locals()))
125125

126126
@classmethod
127-
def load(cls, path, reload=False, **kwargs):
127+
def load(cls, path, reload=False, src=None, **kwargs):
128128
r"""
129129
Loads a parser with data fields and pretrained model parameters.
130130
@@ -135,6 +135,11 @@ def load(cls, path, reload=False, **kwargs):
135135
- a local path to a pretrained model, e.g., ``./<path>/model``.
136136
reload (bool):
137137
Whether to discard the existing cache and force a fresh download. Default: ``False``.
138+
src (str):
139+
Specifies where to download the model.
140+
``'github'``: github release page.
141+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
142+
Default: None.
138143
kwargs (dict):
139144
A dict holding unconsumed arguments for updating training configs and initializing the model.
140145
@@ -144,7 +149,7 @@ def load(cls, path, reload=False, **kwargs):
144149
>>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
145150
"""
146151

147-
return super().load(path, reload, **kwargs)
152+
return super().load(path, reload, src, **kwargs)
148153

149154
def _train(self, loader):
150155
self.model.train()
@@ -425,7 +430,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
425430
return super().predict(**Config().update(locals()))
426431

427432
@classmethod
428-
def load(cls, path, reload=False, **kwargs):
433+
def load(cls, path, reload=False, src=None, **kwargs):
429434
r"""
430435
Loads a parser with data fields and pretrained model parameters.
431436
@@ -436,6 +441,11 @@ def load(cls, path, reload=False, **kwargs):
436441
- a local path to a pretrained model, e.g., ``./<path>/model``.
437442
reload (bool):
438443
Whether to discard the existing cache and force a fresh download. Default: ``False``.
444+
src (str):
445+
Specifies where to download the model.
446+
``'github'``: github release page.
447+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
448+
Default: None.
439449
kwargs (dict):
440450
A dict holding unconsumed arguments for updating training configs and initializing the model.
441451
@@ -445,7 +455,7 @@ def load(cls, path, reload=False, **kwargs):
445455
>>> parser = Parser.load('./ptb.crf.dep.lstm.char')
446456
"""
447457

448-
return super().load(path, reload, **kwargs)
458+
return super().load(path, reload, src, **kwargs)
449459

450460
def _train(self, loader):
451461
self.model.train()
@@ -636,7 +646,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
636646
return super().predict(**Config().update(locals()))
637647

638648
@classmethod
639-
def load(cls, path, reload=False, **kwargs):
649+
def load(cls, path, reload=False, src=None, **kwargs):
640650
r"""
641651
Loads a parser with data fields and pretrained model parameters.
642652
@@ -647,6 +657,11 @@ def load(cls, path, reload=False, **kwargs):
647657
- a local path to a pretrained model, e.g., ``./<path>/model``.
648658
reload (bool):
649659
Whether to discard the existing cache and force a fresh download. Default: ``False``.
660+
src (str):
661+
Specifies where to download the model.
662+
``'github'``: github release page.
663+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
664+
Default: None.
650665
kwargs (dict):
651666
A dict holding unconsumed arguments for updating training configs and initializing the model.
652667
@@ -656,7 +671,7 @@ def load(cls, path, reload=False, **kwargs):
656671
>>> parser = Parser.load('./ptb.crf2o.dep.lstm.char')
657672
"""
658673

659-
return super().load(path, reload, **kwargs)
674+
return super().load(path, reload, src, **kwargs)
660675

661676
def _train(self, loader):
662677
self.model.train()
@@ -933,7 +948,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
933948
return super().predict(**Config().update(locals()))
934949

935950
@classmethod
936-
def load(cls, path, reload=False, **kwargs):
951+
def load(cls, path, reload=False, src=None, **kwargs):
937952
r"""
938953
Loads a parser with data fields and pretrained model parameters.
939954
@@ -944,6 +959,11 @@ def load(cls, path, reload=False, **kwargs):
944959
- a local path to a pretrained model, e.g., ``./<path>/model``.
945960
reload (bool):
946961
Whether to discard the existing cache and force a fresh download. Default: ``False``.
962+
src (str):
963+
Specifies where to download the model.
964+
``'github'``: github release page.
965+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
966+
Default: None.
947967
kwargs (dict):
948968
A dict holding unconsumed arguments for updating training configs and initializing the model.
949969
@@ -953,7 +973,7 @@ def load(cls, path, reload=False, **kwargs):
953973
>>> parser = Parser.load('./ptb.vi.dep.lstm.char')
954974
"""
955975

956-
return super().load(path, reload, **kwargs)
976+
return super().load(path, reload, src, **kwargs)
957977

958978
def _train(self, loader):
959979
self.model.train()

supar/parsers/parser.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,14 @@ def build(cls, path, **kwargs):
153153
raise NotImplementedError
154154

155155
@classmethod
156-
def load(cls, path, reload=False, **kwargs):
156+
def load(cls, path, reload=False, src=None, **kwargs):
157157
args = Config(**locals())
158158
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
159-
state = torch.load(path if os.path.exists(path) else download(supar.MODEL.get(path, path), reload=reload))
159+
if src is not None:
160+
links = {n: f"{supar.SRC[src]}/v{supar.__version__}/{m}.zip" for n, m in supar.NAME.items()}
161+
else:
162+
links = supar.MODEL
163+
state = torch.load(path if os.path.exists(path) else download(links.get(path, path), reload=reload))
160164
cls = supar.PARSER[state['name']] if cls.NAME is None else cls
161165
args = state['args'].update(args)
162166
model = cls.MODEL(**args)

supar/parsers/sdp.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, verbos
101101
return super().predict(**Config().update(locals()))
102102

103103
@classmethod
104-
def load(cls, path, reload=False, **kwargs):
104+
def load(cls, path, reload=False, src=None, **kwargs):
105105
r"""
106106
Loads a parser with data fields and pretrained model parameters.
107107
@@ -112,6 +112,11 @@ def load(cls, path, reload=False, **kwargs):
112112
- a local path to a pretrained model, e.g., ``./<path>/model``.
113113
reload (bool):
114114
Whether to discard the existing cache and force a fresh download. Default: ``False``.
115+
src (str):
116+
Specifies where to download the model.
117+
``'github'``: github release page.
118+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
119+
Default: None.
115120
kwargs (dict):
116121
A dict holding unconsumed arguments for updating training configs and initializing the model.
117122
@@ -121,7 +126,7 @@ def load(cls, path, reload=False, **kwargs):
121126
>>> parser = Parser.load('./dm.biaffine.sdp.lstm.char')
122127
"""
123128

124-
return super().load(path, reload, **kwargs)
129+
return super().load(path, reload, src, **kwargs)
125130

126131
def _train(self, loader):
127132
self.model.train()
@@ -370,7 +375,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, verbos
370375
return super().predict(**Config().update(locals()))
371376

372377
@classmethod
373-
def load(cls, path, reload=False, **kwargs):
378+
def load(cls, path, reload=False, src=None, **kwargs):
374379
r"""
375380
Loads a parser with data fields and pretrained model parameters.
376381
@@ -381,6 +386,11 @@ def load(cls, path, reload=False, **kwargs):
381386
- a local path to a pretrained model, e.g., ``./<path>/model``.
382387
reload (bool):
383388
Whether to discard the existing cache and force a fresh download. Default: ``False``.
389+
src (str):
390+
Specifies where to download the model.
391+
``'github'``: github release page.
392+
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
393+
Default: None.
384394
kwargs (dict):
385395
A dict holding unconsumed arguments for updating training configs and initializing the model.
386396
@@ -390,7 +400,7 @@ def load(cls, path, reload=False, **kwargs):
390400
>>> parser = Parser.load('./dm.vi.sdp.lstm.char')
391401
"""
392402

393-
return super().load(path, reload, **kwargs)
403+
return super().load(path, reload, src, **kwargs)
394404

395405
def _train(self, loader):
396406
self.model.train()

0 commit comments

Comments
 (0)