|
1 |
| - |
2 |
| -from typing import Dict, Any, Optional |
3 |
| - |
4 |
| -from atcodertools.codegen.code_style_config import CodeStyleConfig |
5 | 1 | from atcodertools.codegen.models.code_gen_args import CodeGenArgs
|
6 | 2 | from atcodertools.codegen.template_engine import render
|
7 |
| -from atcodertools.fmtprediction.models.format import Pattern, SingularPattern, ParallelPattern, TwoDimensionalPattern, \ |
8 |
| - Format |
9 |
| -from atcodertools.fmtprediction.models.type import Type |
10 |
| -from atcodertools.fmtprediction.models.variable import Variable |
11 |
| - |
12 |
| - |
13 |
| -def _loop_header(var: Variable, for_second_index: bool): |
14 |
| - if for_second_index: |
15 |
| - index = var.second_index |
16 |
| - loop_var = "j" |
17 |
| - else: |
18 |
| - index = var.first_index |
19 |
| - loop_var = "i" |
20 |
| - |
21 |
| - return "for(int {loop_var} = 0 ; {loop_var} < {length} ; {loop_var}++){{".format( |
22 |
| - loop_var=loop_var, |
23 |
| - length=index.get_length() |
24 |
| - ) |
25 |
| - |
26 |
| - |
27 |
| -class CppCodeGenerator: |
28 |
| - |
29 |
| - def __init__(self, |
30 |
| - format_: Optional[Format[Variable]], |
31 |
| - config: CodeStyleConfig): |
32 |
| - self._format = format_ |
33 |
| - self._config = config |
34 |
| - |
35 |
| - def generate_parameters(self) -> Dict[str, Any]: |
36 |
| - if self._format is None: |
37 |
| - return dict(prediction_success=False) |
38 |
| - |
39 |
| - return dict(formal_arguments=self._formal_arguments(), |
40 |
| - actual_arguments=self._actual_arguments(), |
41 |
| - input_part=self._input_part(), |
42 |
| - prediction_success=True) |
43 |
| - |
44 |
| - def _input_part(self): |
45 |
| - lines = [] |
46 |
| - for pattern in self._format.sequence: |
47 |
| - lines += self._render_pattern(pattern) |
48 |
| - return "\n{indent}".format(indent=self._indent(1)).join(lines) |
49 |
| - |
50 |
| - def _convert_type(self, type_: Type) -> str: |
51 |
| - if type_ == Type.float: |
52 |
| - return "long double" |
53 |
| - elif type_ == Type.int: |
54 |
| - return "long long" |
55 |
| - elif type_ == Type.str: |
56 |
| - return "std::string" |
57 |
| - else: |
58 |
| - raise NotImplementedError |
59 |
| - |
60 |
| - def _get_declaration_type(self, var: Variable): |
61 |
| - ctype = self._convert_type(var.type) |
62 |
| - for _ in range(var.dim_num()): |
63 |
| - ctype = 'std::vector<{}>'.format(ctype) |
64 |
| - return ctype |
65 |
| - |
66 |
| - def _actual_arguments(self) -> str: |
67 |
| - """ |
68 |
| - :return the string form of actual arguments e.g. "N, K, a" |
69 |
| - """ |
70 |
| - return ", ".join([ |
71 |
| - v.name if v.dim_num() == 0 else 'std::move({})'.format(v.name) |
72 |
| - for v in self._format.all_vars()]) |
73 |
| - |
74 |
| - def _formal_arguments(self): |
75 |
| - """ |
76 |
| - :return the string form of formal arguments e.g. "int N, int K, std::vector<int> a" |
77 |
| - """ |
78 |
| - return ", ".join([ |
79 |
| - "{decl_type} {name}".format( |
80 |
| - decl_type=self._get_declaration_type(v), |
81 |
| - name=v.name) |
82 |
| - for v in self._format.all_vars() |
83 |
| - ]) |
84 |
| - |
85 |
| - def _generate_declaration(self, var: Variable): |
86 |
| - """ |
87 |
| - :return: Create declaration part E.g. array[1..n] -> std::vector<int> array = std::vector<int>(n-1+1); |
88 |
| - """ |
89 |
| - if var.dim_num() == 0: |
90 |
| - dims = [] |
91 |
| - elif var.dim_num() == 1: |
92 |
| - dims = [var.first_index.get_length()] |
93 |
| - elif var.dim_num() == 2: |
94 |
| - dims = [var.first_index.get_length(), |
95 |
| - var.second_index.get_length()] |
96 |
| - else: |
97 |
| - raise NotImplementedError |
98 |
| - |
99 |
| - if len(dims) == 0: |
100 |
| - ctor = '' |
101 |
| - elif len(dims) == 1: |
102 |
| - ctor = '({})'.format(dims[0]) |
103 |
| - else: |
104 |
| - ctor = '({})'.format(dims[-1]) |
105 |
| - ctype = self._convert_type(var.type) |
106 |
| - for dim in dims[-2::-1]: |
107 |
| - ctype = 'std::vector<{}>'.format(ctype) |
108 |
| - ctor = '({}, {}{})'.format(dim, ctype, ctor) |
109 |
| - |
110 |
| - line = "{decl_type} {name}{constructor};".format( |
111 |
| - name=var.name, |
112 |
| - decl_type=self._get_declaration_type(var), |
113 |
| - constructor=ctor |
114 |
| - ) |
115 |
| - return line |
116 |
| - |
117 |
| - def _input_code_for_var(self, var: Variable) -> str: |
118 |
| - name = self._get_var_name(var) |
119 |
| - if var.type == Type.float: |
120 |
| - return 'scanf("%Lf",&{name});'.format(name=name) |
121 |
| - elif var.type == Type.int: |
122 |
| - return 'scanf("%lld",&{name});'.format(name=name) |
123 |
| - elif var.type == Type.str: |
124 |
| - return 'std::cin >> {name};'.format(name=name) |
125 |
| - else: |
126 |
| - raise NotImplementedError |
127 |
| - |
128 |
| - @staticmethod |
129 |
| - def _get_var_name(var: Variable): |
130 |
| - name = var.name |
131 |
| - if var.dim_num() >= 1: |
132 |
| - name += "[i]" |
133 |
| - if var.dim_num() >= 2: |
134 |
| - name += "[j]" |
135 |
| - return name |
136 |
| - |
137 |
| - def _render_pattern(self, pattern: Pattern): |
138 |
| - lines = [] |
139 |
| - for var in pattern.all_vars(): |
140 |
| - lines.append(self._generate_declaration(var)) |
141 |
| - |
142 |
| - representative_var = pattern.all_vars()[0] |
143 |
| - if isinstance(pattern, SingularPattern): |
144 |
| - lines.append(self._input_code_for_var(representative_var)) |
145 |
| - elif isinstance(pattern, ParallelPattern): |
146 |
| - lines.append(_loop_header(representative_var, False)) |
147 |
| - for var in pattern.all_vars(): |
148 |
| - lines.append("{indent}{line}".format(indent=self._indent(1), |
149 |
| - line=self._input_code_for_var(var))) |
150 |
| - lines.append("}") |
151 |
| - elif isinstance(pattern, TwoDimensionalPattern): |
152 |
| - lines.append(_loop_header(representative_var, False)) |
153 |
| - lines.append( |
154 |
| - "{indent}{line}".format(indent=self._indent(1), line=_loop_header(representative_var, True))) |
155 |
| - for var in pattern.all_vars(): |
156 |
| - lines.append("{indent}{line}".format(indent=self._indent(2), |
157 |
| - line=self._input_code_for_var(var))) |
158 |
| - lines.append("{indent}}}".format(indent=self._indent(1))) |
159 |
| - lines.append("}") |
160 |
| - else: |
161 |
| - raise NotImplementedError |
162 |
| - |
163 |
| - return lines |
164 |
| - |
165 |
| - def _indent(self, depth): |
166 |
| - return self._config.indent(depth) |
167 |
| - |
168 | 3 |
|
169 |
| -class NoPredictionResultGiven(Exception): |
170 |
| - pass |
| 4 | +from atcodertools.codegen.code_generators.universal_code_generator import CodeGenerator |
| 5 | +from atcodertools.codegen.code_generators.universal_generator.cpp import CodeGeneratorInfo |
171 | 6 |
|
172 | 7 |
|
173 | 8 | def main(args: CodeGenArgs) -> str:
|
174 |
| - code_parameters = CppCodeGenerator( |
175 |
| - args.format, args.config).generate_parameters() |
| 9 | + code_parameters = CodeGenerator( |
| 10 | + args.format, args.config, CodeGeneratorInfo()).generate_parameters() |
176 | 11 | return render(
|
177 | 12 | args.template,
|
178 | 13 | mod=args.constants.mod,
|
|
0 commit comments