Skip to content

Commit d48cda2

Browse files
chaemonkyuridenamida
authored andcommitted
小数での出力による誤差ジャッジ対応 (kyuridenamida#156)
* 小数における誤差ジャッジに対応 * constatns_prediction.pyのdecimalを大文字に, 10^9+7を削除 * 細かな表記を修正 * 指摘事項の容易に反映できるところを反映 * JudgeTypeをEnumに。Judge classをNormalJudge, DecimalJudgeに分離 * flake8のエラーを修正 * decimal caseに対応するように修正 * tester.pyの表記を修正 * Decimal Testを追加。ErrorTypeをEnumに * decimal testの名称変更、judgetypeをoutput, expectedに統一 * テスト * ディレクトリ名を変更
1 parent 1126750 commit d48cda2

File tree

31 files changed

+317
-30
lines changed

31 files changed

+317
-30
lines changed

atcodertools/common/judgetype.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#!/usr/bin/python3
2+
# -*- coding: utf-8 -*-
3+
4+
from enum import Enum
5+
6+
7+
class JudgeType(Enum):
8+
Normal = "normal"
9+
Decimal = "decimal"
10+
Other = "other"
11+
12+
13+
class ErrorType(Enum):
14+
Absolute = "absolute"
15+
Relative = "relative"
16+
AbsoluteOrRelative = "absolute_or_relative"
17+
18+
19+
class Judge:
20+
def verify(self, output, expected):
21+
pass
22+
23+
24+
class NormalJudge(Judge):
25+
def __init__(self):
26+
self.judge_type = JudgeType.Normal
27+
28+
def verify(self, output, expected):
29+
return output == expected
30+
31+
def to_dict(self):
32+
return {
33+
"judge_type": self.judge_type.value,
34+
}
35+
36+
@classmethod
37+
def from_dict(cls, dic):
38+
r = NormalJudge()
39+
return r
40+
41+
42+
class DecimalJudge(Judge):
43+
def __init__(self,
44+
error_type: ErrorType = ErrorType.AbsoluteOrRelative,
45+
diff: float = 0.0
46+
):
47+
self.judge_type = JudgeType.Decimal
48+
self.error_type = error_type
49+
self.diff = diff
50+
51+
def __verify_sub(self, output, expected: float) -> bool:
52+
if self.error_type in [ErrorType.Absolute, ErrorType.AbsoluteOrRelative] and abs(expected - output) <= self.diff:
53+
return True
54+
if self.error_type in [ErrorType.Relative, ErrorType.AbsoluteOrRelative] and abs((expected - output) / expected) <= self.diff:
55+
return True
56+
return False
57+
58+
def verify(self, output, expected) -> bool:
59+
output = output.strip().split()
60+
expected = expected.strip().split()
61+
if len(output) != len(expected):
62+
return False
63+
for i in range(0, len(output)):
64+
if not self.__verify_sub(float(output[i]), float(expected[i])):
65+
return False
66+
return True
67+
68+
def to_dict(self):
69+
return {
70+
"judge_type": self.judge_type.value,
71+
"error_type": self.error_type.value,
72+
"diff": self.diff
73+
}
74+
75+
@classmethod
76+
def from_dict(cls, dic):
77+
r = DecimalJudge(
78+
diff=dic["diff"]
79+
)
80+
r.error_type = ErrorType(dic["error_type"])
81+
return r
82+
83+
84+
class OtherJudge(Judge):
85+
# dummy
86+
pass

atcodertools/constprediction/constants_prediction.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from atcodertools.constprediction.models.problem_constant_set import ProblemConstantSet
77
from atcodertools.client.models.problem_content import ProblemContent, InputFormatDetectionError, SampleDetectionError
88
from atcodertools.common.logging import logger
9+
from atcodertools.common.judgetype import JudgeType, ErrorType, NormalJudge, DecimalJudge
910

1011

1112
class YesNoPredictionFailedError(Exception):
@@ -18,13 +19,28 @@ def __init__(self, cands):
1819
self.cands = cands
1920

2021

22+
class MultipleDecimalCandidatesError(Exception):
23+
24+
def __init__(self, cands):
25+
self.cands = cands
26+
27+
2128
MOD_ANCHORS = ["余り", "あまり", "mod", "割っ", "modulo"]
29+
DECIMAL_ANCHORS = ["誤差", " error "]
2230

2331
MOD_STRATEGY_RE_LIST = [
2432
re.compile("([0-9]+).?.?.?で割った"),
2533
re.compile("modu?l?o?[^0-9]?[^0-9]?[^0-9]?([0-9]+)")
2634
]
2735

36+
DECIMAL_STRATEGY_RE_LIST_KEYWORD = [
37+
re.compile("(?:絶対|相対)誤差"),
38+
re.compile("(?:absolute|relative)")
39+
]
40+
DECIMAL_STRATEGY_RE_LIST_VAL = [
41+
re.compile("10\^(-[0-9]+)"),
42+
]
43+
2844

2945
def is_mod_context(sentence):
3046
for kw in MOD_ANCHORS:
@@ -33,6 +49,13 @@ def is_mod_context(sentence):
3349
return False
3450

3551

52+
def is_decimal_context(sentence):
53+
for kw in DECIMAL_ANCHORS:
54+
if kw in sentence:
55+
return True
56+
return False
57+
58+
3659
def predict_modulo(html: str) -> Optional[int]:
3760
def normalize(sentence):
3861
return sentence.replace('\\', '').replace("{", "").replace("}", "").replace(",", "").replace(" ", "").replace(
@@ -83,6 +106,57 @@ def predict_yes_no(html: str) -> Tuple[Optional[str], Optional[str]]:
83106
return yes_str, no_str
84107

85108

109+
def predict_judge_type(html: str) -> Optional[JudgeType]:
110+
def normalize(sentence):
111+
return sentence.replace('\\', '').replace("{", "").replace("}", "").replace(",", "").replace(" ", "").lower().strip()
112+
113+
soup = BeautifulSoup(html, "html.parser")
114+
sentences = soup.get_text().split("\n")
115+
sentences = [normalize(s) for s in sentences if is_decimal_context(s)]
116+
117+
decimal_keyword_cands = set()
118+
decimal_val_cands = set()
119+
120+
if len(sentences) > 0: # Decimal
121+
is_absolute = False
122+
is_relative = False
123+
for s in sentences:
124+
for regexp in DECIMAL_STRATEGY_RE_LIST_KEYWORD:
125+
r = regexp.findall(s)
126+
for t in r:
127+
if t == "絶対誤差" or t == "absolute":
128+
is_absolute = True
129+
elif t == "相対誤差" or t == "relative":
130+
is_relative = True
131+
decimal_keyword_cands.add(t)
132+
for s in sentences:
133+
for regexp in DECIMAL_STRATEGY_RE_LIST_VAL:
134+
r = regexp.findall(s)
135+
for t in r:
136+
decimal_val_cands.add(int(t))
137+
138+
if len(decimal_val_cands) == 0:
139+
return None
140+
141+
if len(decimal_val_cands) == 1:
142+
if is_absolute:
143+
if is_relative:
144+
error_type = ErrorType.AbsoluteOrRelative
145+
else:
146+
error_type = ErrorType.Absolute
147+
else:
148+
if is_relative:
149+
error_type = ErrorType.Relative
150+
else:
151+
assert(False)
152+
153+
return DecimalJudge(error_type, 10.0**(int(list(decimal_val_cands)[0])))
154+
155+
raise MultipleDecimalCandidatesError(decimal_val_cands)
156+
157+
return NormalJudge()
158+
159+
86160
def predict_constants(html: str) -> ProblemConstantSet:
87161
try:
88162
yes_str, no_str = predict_yes_no(html)
@@ -96,4 +170,11 @@ def predict_constants(html: str) -> ProblemConstantSet:
96170
"two or more candidates {} are detected as modulo values".format(e.cands))
97171
mod = None
98172

99-
return ProblemConstantSet(mod=mod, yes_str=yes_str, no_str=no_str)
173+
try:
174+
judge_type = predict_judge_type(html)
175+
except MultipleModCandidatesError as e:
176+
logger.warning("decimal prediction failed -- "
177+
"two or more candidates {} are detected as decimal values".format(e.cands))
178+
judge_type = NormalJudge()
179+
180+
return ProblemConstantSet(mod=mod, yes_str=yes_str, no_str=no_str, judge_type=judge_type)

atcodertools/constprediction/models/problem_constant_set.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from atcodertools.common.judgetype import JudgeType
12

23

34
class ProblemConstantSet:
@@ -6,7 +7,9 @@ def __init__(self,
67
mod: int = None,
78
yes_str: str = None,
89
no_str: str = None,
10+
judge_type: JudgeType = None,
911
):
1012
self.mod = mod
1113
self.yes_str = yes_str
1214
self.no_str = no_str
15+
self.judge_type = judge_type

atcodertools/executils/run_program.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ def __init__(self, status: ExecStatus, output: str = None, stderr: str = None, e
2121
else:
2222
self.elapsed_ms = None
2323

24-
def is_correct_output(self, answer_text):
25-
return self.status == ExecStatus.NORMAL and answer_text == self.output
24+
def is_correct_output(self, answer_text, judge_type):
25+
if self.status != ExecStatus.NORMAL:
26+
return False
27+
return judge_type.verify(self.output, answer_text)
2628

2729
def has_stderr(self):
2830
return len(self.stderr) > 0

atcodertools/tools/envgen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def emit_info(text):
139139
config.etc_config.in_example_format.replace("{}", "*"),
140140
config.etc_config.out_example_format.replace("{}", "*"),
141141
lang,
142+
constants.judge_type,
142143
).save_to(metadata_path)
143144
emit_info("Saved metadata to {}".format(metadata_path))
144145

atcodertools/tools/models/metadata.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22

33
from atcodertools.client.models.problem import Problem
44
from atcodertools.common.language import Language
5+
from atcodertools.common.judgetype import NormalJudge, DecimalJudge
56

67

78
class Metadata:
89

9-
def __init__(self, problem: Problem, code_filename: str, sample_in_pattern: str, sample_out_pattern: str, lang: Language):
10+
def __init__(self, problem: Problem, code_filename: str, sample_in_pattern: str, sample_out_pattern: str, lang: Language, judge_type=NormalJudge()):
1011
self.problem = problem
1112
self.code_filename = code_filename
1213
self.sample_in_pattern = sample_in_pattern
1314
self.sample_out_pattern = sample_out_pattern
1415
self.lang = lang
16+
self.judge_type = judge_type
1517

1618
def to_dict(self):
1719
return {
@@ -20,16 +22,29 @@ def to_dict(self):
2022
"sample_in_pattern": self.sample_in_pattern,
2123
"sample_out_pattern": self.sample_out_pattern,
2224
"lang": self.lang.name,
25+
"judge": self.judge_type.to_dict(),
2326
}
2427

2528
@classmethod
2629
def from_dict(cls, dic):
30+
if "judge" in dic:
31+
judge_type = dic["judge"]["judge_type"]
32+
if judge_type == "normal":
33+
judge = NormalJudge.from_dict(dic["judge"])
34+
elif judge_type == "decimal":
35+
judge = DecimalJudge.from_dict(dic["judge"])
36+
else:
37+
raise Exception("invalid judge type")
38+
else:
39+
judge = NormalJudge()
40+
2741
return Metadata(
2842
problem=Problem.from_dict(dic["problem"]),
2943
code_filename=dic["code_filename"],
3044
sample_in_pattern=dic["sample_in_pattern"],
3145
sample_out_pattern=dic["sample_out_pattern"],
3246
lang=Language.from_name(dic["lang"]),
47+
judge_type=judge
3348
)
3449

3550
@classmethod
@@ -40,3 +55,4 @@ def load_from(cls, filename):
4055
def save_to(self, filename):
4156
with open(filename, 'w') as f:
4257
json.dump(self.to_dict(), f, indent=1, sort_keys=True)
58+
f.write('\n')

0 commit comments

Comments
 (0)