Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
stefantaubert committed Jun 13, 2023
1 parent 4368f0b commit d65b96d
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 7 deletions.
26 changes: 21 additions & 5 deletions src/tts_mos_test_mturk/evaluation_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from ordered_set import OrderedSet

from tts_mos_test_mturk.io import load_obj, save_obj
from tts_mos_test_mturk.masking.masks import MaskBase
from tts_mos_test_mturk.masking.masks import MaskBase, get_mask_name_and_reverse
from tts_mos_test_mturk.result import Result, Worker
from tts_mos_test_mturk.typing import MaskName


class EvaluationData():
Expand Down Expand Up @@ -79,15 +80,30 @@ def n_workers(self) -> int:
def n_files(self) -> int:
return len(self.files)

def get_mask(self, mask_name: str) -> MaskBase:
def get_mask(self, mask_name: MaskName) -> MaskBase:
mask_name, reverse = get_mask_name_and_reverse(mask_name)
if mask_name not in self.masks:
raise ValueError(f"Mask \"{mask_name}\" doesn't exist!")
mask = self.masks[mask_name]
if reverse:
new_mask = mask.clone()
new_mask.reverse()
return new_mask
return mask

def get_masks_from_names(self, mask_names: Set[MaskName]) -> List[MaskBase]:
masks = [self.get_mask(mask_name) for mask_name in mask_names]
return masks

def get_mask_old(self, mask_name: MaskName) -> MaskBase:
if mask_name not in self.masks:
raise ValueError(f"Mask \"{mask_name}\" doesn't exist!")
return self.masks[mask_name]

def get_masks_from_names(self, mask_names: Set[str]) -> List[MaskBase]:
masks = [self.get_mask(mask_name) for mask_name in mask_names]
def get_masks_from_names_old(self, mask_names: Set[MaskName]) -> List[MaskBase]:
masks = [self.get_mask_old(mask_name) for mask_name in mask_names]
return masks

def add_or_update_mask(self, name: str, mask: MaskBase) -> None:
def add_or_update_mask(self, name: MaskName, mask: MaskBase) -> None:
assert name is not None
self.masks[name] = mask
11 changes: 11 additions & 0 deletions src/tts_mos_test_mturk/masking/masks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Tuple

import numpy as np

REVERSE_INDICATOR = "!"


class MaskBase():
def __init__(self, mask: np.ndarray) -> None:
Expand Down Expand Up @@ -68,3 +72,10 @@ def clone(self) -> "AssignmentsMask":
class WorkersMask(MaskBase):
def clone(self) -> "WorkersMask":
return WorkersMask(self.mask.copy())


def get_mask_name_and_reverse(mask_name: str) -> Tuple[str, bool]:
if mask_name.startswith(REVERSE_INDICATOR):
mask_name = mask_name[len(REVERSE_INDICATOR):]
return mask_name, True
return mask_name, False
8 changes: 8 additions & 0 deletions src/tts_mos_test_mturk_cli/argparse_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ordered_set import OrderedSet

from tts_mos_test_mturk.evaluation_data import EvaluationData
from tts_mos_test_mturk.masking.masks import REVERSE_INDICATOR

T = TypeVar("T")

Expand Down Expand Up @@ -123,6 +124,13 @@ def parse_non_empty_or_whitespace(value: str) -> str:
return value


def parse_output_mask_name(value: str) -> str:
value = parse_non_empty_or_whitespace(value)
if value.startswith(REVERSE_INDICATOR):
raise ArgumentTypeError(f"Value must not start with \"{REVERSE_INDICATOR}\"!")
return value


def parse_float(value: str) -> float:
value = parse_required(value)
try:
Expand Down
5 changes: 3 additions & 2 deletions src/tts_mos_test_mturk_cli/default_args.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from tts_mos_test_mturk_cli.argparse_helper import (ConvertToSetAction,
parse_non_empty_or_whitespace, parse_project)
parse_non_empty_or_whitespace,
parse_output_mask_name, parse_project)


def add_req_ratings_argument(parser: ArgumentParser) -> None:
Expand All @@ -20,7 +21,7 @@ def add_opt_masks_argument(parser: ArgumentParser) -> None:


def add_req_output_mask_argument(parser: ArgumentParser) -> None:
parser.add_argument("output_mask", type=parse_non_empty_or_whitespace,
parser.add_argument("output_mask", type=parse_output_mask_name,
metavar="OUTPUT-MASK", help="name of the output mask")


Expand Down
2 changes: 2 additions & 0 deletions src/tts_mos_test_mturk_cli/parsers/masks_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ def main(ns: Namespace) -> None:
save_project(ns.project)
return main


def init_reverse_masks_parser(parser: ArgumentParser):
# can be removed bc !maskname does the same thing
parser.description = "Reverse mask."
add_req_project_argument(parser)
parser.add_argument("mask", type=parse_non_empty_or_whitespace,
Expand Down
2 changes: 2 additions & 0 deletions src/tts_mos_test_mturk_cli/validation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Set

from tts_mos_test_mturk.evaluation_data import EvaluationData
from tts_mos_test_mturk.masking.masks import get_mask_name_and_reverse
from tts_mos_test_mturk_cli.types import CLIError


def ensure_mask_exists(data: EvaluationData, mask: str) -> None:
mask, _ = get_mask_name_and_reverse(mask)
if mask not in data.masks:
raise CLIError(f"Mask \"{mask}\" doesn't exist!")

Expand Down

0 comments on commit d65b96d

Please sign in to comment.