Skip to content

Commit 5e437cb

Browse files
committed
Support (de-)binarize bytes
1 parent 4becaa3 commit 5e437cb

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

supar/utils/fn.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
import torch
1616
from omegaconf import DictConfig, OmegaConf
1717
from supar.utils.common import CACHE
18-
from supar.utils.logging import progress_bar
1918

2019

2120
def ispunct(token: str) -> bool:
@@ -291,11 +290,11 @@ def extract(path: str, reload: bool = False, clean: bool = False) -> str:
291290
return extracted
292291

293292

294-
def binarize(data: Iterable, fbin: str = None) -> None:
293+
def binarize(data: Iterable, fbin: str = None, byte: bool = False) -> None:
295294
start, meta = 0, []
296295
with open(fbin, 'wb') as f:
297-
for s in progress_bar(data):
298-
bytes = pickle.dumps(s)
296+
for s in data:
297+
bytes = pickle.dumps(s) if not byte else s
299298
f.write(bytes)
300299
meta.append((start, len(bytes)))
301300
start = start + len(bytes)
@@ -306,15 +305,16 @@ def binarize(data: Iterable, fbin: str = None) -> None:
306305
f.write(pickle.dumps(torch.tensor((start, start + len(meta)))))
307306

308307

309-
def debinarize(fbin: str, offset: Optional[int] = 0, length: Optional[int] = 0, meta: bool = False) -> Any:
310-
with open(fbin, 'rb') as f:
311-
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
312-
if meta:
313-
length = len(pickle.dumps(torch.tensor((offset, length))))
314-
mm.seek(-length, os.SEEK_END)
315-
offset, length = pickle.loads(mm.read(length)).tolist()
316-
mm.seek(offset)
317-
return pickle.loads(mm.read(length))
308+
def debinarize(fbin: str, position: Optional[Tuple[int, int]] = (0, 0), meta: bool = False, byte: bool = False) -> Any:
309+
offset, length = position
310+
with open(fbin, 'rb') as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
311+
if meta:
312+
length = len(pickle.dumps(torch.tensor(position)))
313+
mm.seek(-length, os.SEEK_END)
314+
offset, length = pickle.loads(mm.read(length)).tolist()
315+
mm.seek(offset)
316+
bytes = mm.read(length)
317+
return bytes if byte else pickle.loads(bytes)
318318

319319

320320
def resolve_config(args: Union[Dict, DictConfig]) -> DictConfig:

0 commit comments

Comments
 (0)