5
5
import functools
6
6
import operator
7
7
from enum import Enum
8
+ from utils import lookahead
8
9
9
10
class PacketTypeId (Enum ):
10
11
SUM = 0
@@ -16,57 +17,149 @@ class PacketTypeId(Enum):
16
17
LESS_THAN = 6
17
18
EQUAL = 7
18
19
20
+ class OperatorStyle (Enum ):
21
+ LENGTH = '0'
22
+ COUNT = '1'
23
+
19
24
class Packet (ABC ):
20
- def __init__ (self , version : int , type_id : int ):
25
+ @classmethod
26
+ def decode (cls , transmission : str ) -> "Packet" :
27
+ version = int (transmission [0 :3 ], 2 )
28
+ type_id = int (transmission [3 :6 ], 2 )
29
+ consumed = 6
30
+
31
+ if type_id == PacketTypeId .LITERAL .value :
32
+ packet , new_consumed = LiteralPacket .decode (version , transmission [consumed :])
33
+ consumed += new_consumed
34
+ else :
35
+ packet , new_consumed = OperatorPacket .decode (version , type_id , transmission [consumed :])
36
+ consumed += new_consumed
37
+
38
+ return packet , consumed
39
+
40
+ def __init__ (self , version : int , type_id : PacketTypeId ):
21
41
self .version = version
22
42
self .type_id = type_id
23
43
24
44
@abstractmethod
25
- def value (self ):
45
+ def value (self ) -> int :
26
46
pass
27
47
48
+ def encode (self ) -> str :
49
+ header = bin (self .version )[2 :].zfill (3 ) + bin (self .type_id .value )[2 :].zfill (3 )
50
+ contents = self .encode_contents ()
51
+ return header + contents
52
+
53
+ @abstractmethod
54
+ def encode_contents (self ) -> str :
55
+ pass
56
+
57
+
28
58
class LiteralPacket (Packet ):
59
+ @classmethod
60
+ def decode (cls , version : int , stream : str ) -> "LiteralPacket" :
61
+ offset = 0
62
+ nibbles = []
63
+ while True :
64
+ is_last = stream [offset ] == '0'
65
+ nibbles .append (int (stream [offset + 1 :offset + 5 ], 2 ))
66
+ offset += 5
67
+ if is_last :
68
+ break
69
+
70
+ literal = 0
71
+ for i in range (len (nibbles )):
72
+ literal += nibbles [i ]<< ((len (nibbles )- i - 1 )* 4 )
73
+
74
+ return LiteralPacket (version , literal ), offset
75
+
29
76
def __init__ (self , version : int , literal : int ):
30
- super ().__init__ (version , PacketTypeId .LITERAL . value )
77
+ super ().__init__ (version , PacketTypeId .LITERAL )
31
78
self .literal = literal
32
79
33
80
def value (self ) -> int :
34
81
return self .literal
35
82
83
+ def encode_contents (self ) -> str :
84
+ outs = []
85
+ value = self .literal
86
+ while value :
87
+ outs .append (bin (value & 15 )[2 :].zfill (4 ))
88
+ value = value >> 4
89
+
90
+ out = ''
91
+ for digit , more in lookahead (outs [::- 1 ]):
92
+ if more :
93
+ out += '1'
94
+ else :
95
+ out += '0'
96
+ out += digit
97
+
98
+ return out
99
+
36
100
def __str__ (self ) -> str :
37
101
return f'Literal<{ self .version } , { self .type_id } , { self .literal } >'
38
102
39
103
class OperatorPacket (Packet ):
40
- def __init__ (self , version : int , type_id : int , subpackets : List [Packet ]):
104
+
105
+ @classmethod
106
+ def decode (cls , version : int , type_id : PacketTypeId , stream : str ) -> "OperatorPacket" :
107
+ packets = []
108
+ style = None
109
+ if stream [0 ] == OperatorStyle .LENGTH .value :
110
+ style = OperatorStyle .LENGTH
111
+ consumed = 16
112
+ subpacket_length = int (stream [1 :16 ], 2 )
113
+
114
+ while subpacket_length > 0 :
115
+ packet , new_consumed = Packet .decode (stream [consumed :])
116
+ consumed += new_consumed
117
+ subpacket_length -= new_consumed
118
+ packets .append (packet )
119
+ elif stream [0 ] == OperatorStyle .COUNT .value :
120
+ style = OperatorStyle .COUNT
121
+ consumed = 12
122
+ subpacket_count = int (stream [1 :12 ], 2 )
123
+ for _ in range (subpacket_count ):
124
+ packet , new_consumed = Packet .decode (stream [consumed :])
125
+ consumed += new_consumed
126
+ packets .append (packet )
127
+ else :
128
+ raise Exception (stream [0 ])
129
+
130
+ return OperatorPacket (version , PacketTypeId (type_id ), style , packets ), consumed
131
+
132
+ def __init__ (self , version : int , type_id : int , style : OperatorStyle , subpackets : List [Packet ]):
41
133
super ().__init__ (version , type_id )
134
+ self .style = style
42
135
self .subpackets = subpackets
43
136
44
137
def value (self ) -> int :
45
- if self .type_id == PacketTypeId .SUM . value :
138
+ if self .type_id == PacketTypeId .SUM :
46
139
return sum (p .value () for p in self .subpackets )
47
140
48
- elif self .type_id == PacketTypeId .PRODUCT . value :
141
+ elif self .type_id == PacketTypeId .PRODUCT :
49
142
return functools .reduce (operator .mul , [p .value () for p in self .subpackets ])
50
143
51
- elif self .type_id == PacketTypeId .MINIMUM . value :
144
+ elif self .type_id == PacketTypeId .MINIMUM :
52
145
return min (p .value () for p in self .subpackets )
53
146
54
- elif self .type_id == PacketTypeId .MAXIMUM . value :
147
+ elif self .type_id == PacketTypeId .MAXIMUM :
55
148
return max (p .value () for p in self .subpackets )
56
149
57
- elif self .type_id == PacketTypeId .GREATER_THAN . value :
150
+ elif self .type_id == PacketTypeId .GREATER_THAN :
58
151
if self .subpackets [0 ].value () > self .subpackets [1 ].value ():
59
152
return 1
60
153
else :
61
154
return 0
62
155
63
- elif self .type_id == PacketTypeId .LESS_THAN . value :
156
+ elif self .type_id == PacketTypeId .LESS_THAN :
64
157
if self .subpackets [0 ].value () < self .subpackets [1 ].value ():
65
158
return 1
66
159
else :
67
160
return 0
68
161
69
- elif self .type_id == PacketTypeId .EQUAL . value :
162
+ elif self .type_id == PacketTypeId .EQUAL :
70
163
if self .subpackets [0 ].value () == self .subpackets [1 ].value ():
71
164
return 1
72
165
else :
@@ -75,13 +168,27 @@ def value(self) -> int:
75
168
else :
76
169
raise Exception (self .type_id )
77
170
171
+ def encode_contents (self ) -> str :
172
+ out = ''
173
+ for packet in self .subpackets :
174
+ out += packet .encode ()
175
+
176
+ if self .style == OperatorStyle .LENGTH :
177
+ out = self .style .value + bin (len (out ))[2 :].zfill (15 ) + out
178
+
179
+ elif self .style == OperatorStyle .COUNT :
180
+ out = self .style .value + bin (len (self .subpackets ))[2 :].zfill (11 ) + out
181
+
182
+ return out
183
+
78
184
def __str__ (self ) -> str :
79
185
out = f'Operator<{ self .version } , { self .type_id } >\n '
80
186
81
187
for packet in self .subpackets :
82
- out += f'{ packet } \n '
188
+ for line in str (packet ).split ('\n ' ):
189
+ out += f' { line } \n '
83
190
84
- return out
191
+ return out [: - 1 ]
85
192
86
193
87
194
def read_input (textio : TextIO ) -> str :
@@ -93,69 +200,12 @@ def read_input(textio: TextIO) -> str:
93
200
out += bin (int (char , 16 ))[2 :].zfill (4 )
94
201
yield out
95
202
96
- def decode_packet_header (first ):
97
- return int (first [0 :3 ], 2 ), int (first [3 :6 ], 2 ), 6
98
-
99
- def decode_literal (version : int , stream : str ) -> LiteralPacket :
100
- offset = 0
101
- nibbles = []
102
- while True :
103
- is_last = stream [offset ] == '0'
104
- nibbles .append (int (stream [offset + 1 :offset + 5 ], 2 ))
105
- offset += 5
106
- if is_last :
107
- break
108
-
109
- literal = 0
110
- for i in range (len (nibbles )):
111
- literal += nibbles [i ]<< ((len (nibbles )- i - 1 )* 4 )
112
-
113
- # print(f'nibbles: {nibbles} = {literal}')
114
-
115
- return LiteralPacket (version , literal ), offset
116
-
117
- def decode_operator_packet (version : int , type_id : int , stream : str ) -> OperatorPacket :
118
- packets = []
119
- if stream [0 ] == '0' :
120
- consumed = 16
121
- subpacket_length = int (stream [1 :16 ], 2 )
122
- # print(f'op length: {subpacket_length}')
123
- while subpacket_length > 0 :
124
- packet , new_consumed = decode_packet (stream [consumed :])
125
- # print(packet, new_consumed)
126
- consumed += new_consumed
127
- subpacket_length -= new_consumed
128
- packets .append (packet )
129
- elif stream [0 ] == '1' :
130
- consumed = 12
131
- subpacket_count = int (stream [1 :12 ], 2 )
132
- # print(f'op count: {subpacket_count}')
133
- for _ in range (subpacket_count ):
134
- packet , new_consumed = decode_packet (stream [consumed :])
135
- consumed += new_consumed
136
- packets .append (packet )
137
- else :
138
- raise Exception (stream [0 ])
139
-
140
- return OperatorPacket (version , type_id , packets ), consumed
141
-
142
- def decode_packet (transmission : str ) -> Packet :
143
- consumed = 0
144
- version , type_id , new_consumed = decode_packet_header (transmission )
145
- consumed += new_consumed
146
- # print(version, type_id, consumed)
147
-
148
- if type_id == PacketTypeId .LITERAL .value :
149
- packet , new_consumed = decode_literal (version , transmission [consumed :])
150
- # print(f'read literal "{literal}" with {consumed} bits')
151
- consumed += new_consumed
152
- else :
153
- packet , new_consumed = decode_operator_packet (version , type_id , transmission [consumed :])
154
- consumed += new_consumed
155
-
156
- return packet , consumed
157
-
158
203
for transmission in read_input (sys .stdin ):
159
- print (transmission )
160
- packet , consumed = decode_packet (transmission )
161
- print (packet .value ())
204
+ # print(transmission, len(transmission))
205
+ packet , consumed = Packet .decode (transmission )
206
+
207
+ # encoded = packet.encode()
208
+ # encoded += '0' * (8-((len(encoded)) % 8))
209
+ # print(encoded, len(encoded))
210
+ print (packet .value ())
211
+ # print(packet)
0 commit comments