@@ -124,7 +124,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
124
124
return super ().predict (** Config ().update (locals ()))
125
125
126
126
@classmethod
127
- def load (cls , path , reload = False , ** kwargs ):
127
+ def load (cls , path , reload = False , src = None , ** kwargs ):
128
128
r"""
129
129
Loads a parser with data fields and pretrained model parameters.
130
130
@@ -135,6 +135,11 @@ def load(cls, path, reload=False, **kwargs):
135
135
- a local path to a pretrained model, e.g., ``./<path>/model``.
136
136
reload (bool):
137
137
Whether to discard the existing cache and force a fresh download. Default: ``False``.
138
+ src (str):
139
+ Specifies where to download the model.
140
+ ``'github'``: github release page.
141
+ ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
142
+ Default: None.
138
143
kwargs (dict):
139
144
A dict holding unconsumed arguments for updating training configs and initializing the model.
140
145
@@ -144,7 +149,7 @@ def load(cls, path, reload=False, **kwargs):
144
149
>>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
145
150
"""
146
151
147
- return super ().load (path , reload , ** kwargs )
152
+ return super ().load (path , reload , src , ** kwargs )
148
153
149
154
def _train (self , loader ):
150
155
self .model .train ()
@@ -425,7 +430,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
425
430
return super ().predict (** Config ().update (locals ()))
426
431
427
432
@classmethod
428
- def load (cls , path , reload = False , ** kwargs ):
433
+ def load (cls , path , reload = False , src = None , ** kwargs ):
429
434
r"""
430
435
Loads a parser with data fields and pretrained model parameters.
431
436
@@ -436,6 +441,11 @@ def load(cls, path, reload=False, **kwargs):
436
441
- a local path to a pretrained model, e.g., ``./<path>/model``.
437
442
reload (bool):
438
443
Whether to discard the existing cache and force a fresh download. Default: ``False``.
444
+ src (str):
445
+ Specifies where to download the model.
446
+ ``'github'``: github release page.
447
+ ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
448
+ Default: None.
439
449
kwargs (dict):
440
450
A dict holding unconsumed arguments for updating training configs and initializing the model.
441
451
@@ -445,7 +455,7 @@ def load(cls, path, reload=False, **kwargs):
445
455
>>> parser = Parser.load('./ptb.crf.dep.lstm.char')
446
456
"""
447
457
448
- return super ().load (path , reload , ** kwargs )
458
+ return super ().load (path , reload , src , ** kwargs )
449
459
450
460
def _train (self , loader ):
451
461
self .model .train ()
@@ -636,7 +646,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
636
646
return super ().predict (** Config ().update (locals ()))
637
647
638
648
@classmethod
639
- def load (cls , path , reload = False , ** kwargs ):
649
+ def load (cls , path , reload = False , src = None , ** kwargs ):
640
650
r"""
641
651
Loads a parser with data fields and pretrained model parameters.
642
652
@@ -647,6 +657,11 @@ def load(cls, path, reload=False, **kwargs):
647
657
- a local path to a pretrained model, e.g., ``./<path>/model``.
648
658
reload (bool):
649
659
Whether to discard the existing cache and force a fresh download. Default: ``False``.
660
+ src (str):
661
+ Specifies where to download the model.
662
+ ``'github'``: github release page.
663
+ ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
664
+ Default: None.
650
665
kwargs (dict):
651
666
A dict holding unconsumed arguments for updating training configs and initializing the model.
652
667
@@ -656,7 +671,7 @@ def load(cls, path, reload=False, **kwargs):
656
671
>>> parser = Parser.load('./ptb.crf2o.dep.lstm.char')
657
672
"""
658
673
659
- return super ().load (path , reload , ** kwargs )
674
+ return super ().load (path , reload , src , ** kwargs )
660
675
661
676
def _train (self , loader ):
662
677
self .model .train ()
@@ -933,7 +948,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
933
948
return super ().predict (** Config ().update (locals ()))
934
949
935
950
@classmethod
936
- def load (cls , path , reload = False , ** kwargs ):
951
+ def load (cls , path , reload = False , src = None , ** kwargs ):
937
952
r"""
938
953
Loads a parser with data fields and pretrained model parameters.
939
954
@@ -944,6 +959,11 @@ def load(cls, path, reload=False, **kwargs):
944
959
- a local path to a pretrained model, e.g., ``./<path>/model``.
945
960
reload (bool):
946
961
Whether to discard the existing cache and force a fresh download. Default: ``False``.
962
+ src (str):
963
+ Specifies where to download the model.
964
+ ``'github'``: github release page.
965
+ ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
966
+ Default: None.
947
967
kwargs (dict):
948
968
A dict holding unconsumed arguments for updating training configs and initializing the model.
949
969
@@ -953,7 +973,7 @@ def load(cls, path, reload=False, **kwargs):
953
973
>>> parser = Parser.load('./ptb.vi.dep.lstm.char')
954
974
"""
955
975
956
- return super ().load (path , reload , ** kwargs )
976
+ return super ().load (path , reload , src , ** kwargs )
957
977
958
978
def _train (self , loader ):
959
979
self .model .train ()
0 commit comments