Skip to content

Commit 4a4ea51

Browse files
committed
Add packet encoding
1 parent ed1ffa7 commit 4a4ea51

File tree

2 files changed

+147
-79
lines changed

2 files changed

+147
-79
lines changed

day_16/__main__.py

Lines changed: 128 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import functools
66
import operator
77
from enum import Enum
8+
from utils import lookahead
89

910
class PacketTypeId(Enum):
1011
SUM = 0
@@ -16,57 +17,149 @@ class PacketTypeId(Enum):
1617
LESS_THAN = 6
1718
EQUAL = 7
1819

20+
class OperatorStyle(Enum):
21+
LENGTH = '0'
22+
COUNT = '1'
23+
1924
class Packet(ABC):
20-
def __init__(self, version: int, type_id: int):
25+
@classmethod
26+
def decode(cls, transmission: str) -> "Packet":
27+
version = int(transmission[0:3], 2)
28+
type_id = int(transmission[3:6], 2)
29+
consumed = 6
30+
31+
if type_id == PacketTypeId.LITERAL.value:
32+
packet, new_consumed = LiteralPacket.decode(version, transmission[consumed:])
33+
consumed += new_consumed
34+
else:
35+
packet, new_consumed = OperatorPacket.decode(version, type_id, transmission[consumed:])
36+
consumed += new_consumed
37+
38+
return packet, consumed
39+
40+
def __init__(self, version: int, type_id: PacketTypeId):
2141
self.version = version
2242
self.type_id = type_id
2343

2444
@abstractmethod
25-
def value(self):
45+
def value(self) -> int:
2646
pass
2747

48+
def encode(self) -> str:
49+
header = bin(self.version)[2:].zfill(3) + bin(self.type_id.value)[2:].zfill(3)
50+
contents = self.encode_contents()
51+
return header + contents
52+
53+
@abstractmethod
54+
def encode_contents(self) -> str:
55+
pass
56+
57+
2858
class LiteralPacket(Packet):
59+
@classmethod
60+
def decode(cls, version: int, stream: str) -> "LiteralPacket":
61+
offset = 0
62+
nibbles = []
63+
while True:
64+
is_last = stream[offset] == '0'
65+
nibbles.append(int(stream[offset+1:offset+5], 2))
66+
offset += 5
67+
if is_last:
68+
break
69+
70+
literal = 0
71+
for i in range(len(nibbles)):
72+
literal += nibbles[i]<<((len(nibbles)-i-1)*4)
73+
74+
return LiteralPacket(version, literal), offset
75+
2976
def __init__(self, version: int, literal: int):
30-
super().__init__(version, PacketTypeId.LITERAL.value)
77+
super().__init__(version, PacketTypeId.LITERAL)
3178
self.literal = literal
3279

3380
def value(self) -> int:
3481
return self.literal
3582

83+
def encode_contents(self) -> str:
84+
outs = []
85+
value = self.literal
86+
while value:
87+
outs.append(bin(value & 15)[2:].zfill(4))
88+
value = value >> 4
89+
90+
out = ''
91+
for digit, more in lookahead(outs[::-1]):
92+
if more:
93+
out += '1'
94+
else:
95+
out += '0'
96+
out += digit
97+
98+
return out
99+
36100
def __str__(self) -> str:
37101
return f'Literal<{self.version}, {self.type_id}, {self.literal}>'
38102

39103
class OperatorPacket(Packet):
40-
def __init__(self, version: int, type_id: int, subpackets: List[Packet]):
104+
105+
@classmethod
106+
def decode(cls, version: int, type_id: PacketTypeId, stream: str) -> "OperatorPacket":
107+
packets = []
108+
style = None
109+
if stream[0] == OperatorStyle.LENGTH.value:
110+
style = OperatorStyle.LENGTH
111+
consumed = 16
112+
subpacket_length = int(stream[1:16], 2)
113+
114+
while subpacket_length > 0:
115+
packet, new_consumed = Packet.decode(stream[consumed:])
116+
consumed += new_consumed
117+
subpacket_length -= new_consumed
118+
packets.append(packet)
119+
elif stream[0] == OperatorStyle.COUNT.value:
120+
style = OperatorStyle.COUNT
121+
consumed = 12
122+
subpacket_count = int(stream[1:12], 2)
123+
for _ in range(subpacket_count):
124+
packet, new_consumed = Packet.decode(stream[consumed:])
125+
consumed += new_consumed
126+
packets.append(packet)
127+
else:
128+
raise Exception(stream[0])
129+
130+
return OperatorPacket(version, PacketTypeId(type_id), style, packets), consumed
131+
132+
def __init__(self, version: int, type_id: int, style: OperatorStyle, subpackets: List[Packet]):
41133
super().__init__(version, type_id)
134+
self.style = style
42135
self.subpackets = subpackets
43136

44137
def value(self) -> int:
45-
if self.type_id == PacketTypeId.SUM.value:
138+
if self.type_id == PacketTypeId.SUM:
46139
return sum(p.value() for p in self.subpackets)
47140

48-
elif self.type_id == PacketTypeId.PRODUCT.value:
141+
elif self.type_id == PacketTypeId.PRODUCT:
49142
return functools.reduce(operator.mul, [p.value() for p in self.subpackets])
50143

51-
elif self.type_id == PacketTypeId.MINIMUM.value:
144+
elif self.type_id == PacketTypeId.MINIMUM:
52145
return min(p.value() for p in self.subpackets)
53146

54-
elif self.type_id == PacketTypeId.MAXIMUM.value:
147+
elif self.type_id == PacketTypeId.MAXIMUM:
55148
return max(p.value() for p in self.subpackets)
56149

57-
elif self.type_id == PacketTypeId.GREATER_THAN.value:
150+
elif self.type_id == PacketTypeId.GREATER_THAN:
58151
if self.subpackets[0].value() > self.subpackets[1].value():
59152
return 1
60153
else:
61154
return 0
62155

63-
elif self.type_id == PacketTypeId.LESS_THAN.value:
156+
elif self.type_id == PacketTypeId.LESS_THAN:
64157
if self.subpackets[0].value() < self.subpackets[1].value():
65158
return 1
66159
else:
67160
return 0
68161

69-
elif self.type_id == PacketTypeId.EQUAL.value:
162+
elif self.type_id == PacketTypeId.EQUAL:
70163
if self.subpackets[0].value() == self.subpackets[1].value():
71164
return 1
72165
else:
@@ -75,13 +168,27 @@ def value(self) -> int:
75168
else:
76169
raise Exception(self.type_id)
77170

171+
def encode_contents(self) -> str:
172+
out = ''
173+
for packet in self.subpackets:
174+
out += packet.encode()
175+
176+
if self.style == OperatorStyle.LENGTH:
177+
out = self.style.value + bin(len(out))[2:].zfill(15) + out
178+
179+
elif self.style == OperatorStyle.COUNT:
180+
out = self.style.value + bin(len(self.subpackets))[2:].zfill(11) + out
181+
182+
return out
183+
78184
def __str__(self) -> str:
79185
out = f'Operator<{self.version}, {self.type_id}>\n'
80186

81187
for packet in self.subpackets:
82-
out += f'{packet}\n'
188+
for line in str(packet).split('\n'):
189+
out += f' {line}\n'
83190

84-
return out
191+
return out[:-1]
85192

86193

87194
def read_input(textio: TextIO) -> str:
@@ -93,69 +200,12 @@ def read_input(textio: TextIO) -> str:
93200
out += bin(int(char, 16))[2:].zfill(4)
94201
yield out
95202

96-
def decode_packet_header(first):
97-
return int(first[0:3], 2), int(first[3:6], 2), 6
98-
99-
def decode_literal(version: int, stream: str) -> LiteralPacket:
100-
offset = 0
101-
nibbles = []
102-
while True:
103-
is_last = stream[offset] == '0'
104-
nibbles.append(int(stream[offset+1:offset+5], 2))
105-
offset += 5
106-
if is_last:
107-
break
108-
109-
literal = 0
110-
for i in range(len(nibbles)):
111-
literal += nibbles[i]<<((len(nibbles)-i-1)*4)
112-
113-
# print(f'nibbles: {nibbles} = {literal}')
114-
115-
return LiteralPacket(version, literal), offset
116-
117-
def decode_operator_packet(version: int, type_id: int, stream: str) -> OperatorPacket:
118-
packets = []
119-
if stream[0] == '0':
120-
consumed = 16
121-
subpacket_length = int(stream[1:16], 2)
122-
# print(f'op length: {subpacket_length}')
123-
while subpacket_length > 0:
124-
packet, new_consumed = decode_packet(stream[consumed:])
125-
# print(packet, new_consumed)
126-
consumed += new_consumed
127-
subpacket_length -= new_consumed
128-
packets.append(packet)
129-
elif stream[0] == '1':
130-
consumed = 12
131-
subpacket_count = int(stream[1:12], 2)
132-
# print(f'op count: {subpacket_count}')
133-
for _ in range(subpacket_count):
134-
packet, new_consumed = decode_packet(stream[consumed:])
135-
consumed += new_consumed
136-
packets.append(packet)
137-
else:
138-
raise Exception(stream[0])
139-
140-
return OperatorPacket(version, type_id, packets), consumed
141-
142-
def decode_packet(transmission: str) -> Packet:
143-
consumed = 0
144-
version, type_id, new_consumed = decode_packet_header(transmission)
145-
consumed += new_consumed
146-
# print(version, type_id, consumed)
147-
148-
if type_id == PacketTypeId.LITERAL.value:
149-
packet, new_consumed = decode_literal(version, transmission[consumed:])
150-
# print(f'read literal "{literal}" with {consumed} bits')
151-
consumed += new_consumed
152-
else:
153-
packet, new_consumed = decode_operator_packet(version, type_id, transmission[consumed:])
154-
consumed += new_consumed
155-
156-
return packet, consumed
157-
158203
for transmission in read_input(sys.stdin):
159-
print(transmission)
160-
packet, consumed = decode_packet(transmission)
161-
print(packet.value())
204+
# print(transmission, len(transmission))
205+
packet, consumed = Packet.decode(transmission)
206+
207+
# encoded = packet.encode()
208+
# encoded += '0' * (8-((len(encoded)) % 8))
209+
# print(encoded, len(encoded))
210+
print(packet.value())
211+
# print(packet)

utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,22 @@ def read_lines(textio: TextIO, spec: Dict[str, type]) -> List[Dict[str, Any]]:
99
for index, (piece_name, piece_type) in enumerate(spec.items()):
1010
blah[piece_name] = piece_type(pieces[index])
1111
output.append(blah)
12-
return output
12+
return output
13+
14+
def lookahead(iterable):
15+
"""Pass through all values from the given iterable, augmented by the
16+
information if there are more values to come after the current one
17+
(True), or if it is the last value (False).
18+
19+
https://stackoverflow.com/a/1630350/111777
20+
"""
21+
# Get an iterator and pull the first value.
22+
it = iter(iterable)
23+
last = next(it)
24+
# Run the iterator to exhaustion (starting from the second value).
25+
for val in it:
26+
# Report the *previous* value (more to come).
27+
yield last, True
28+
last = val
29+
# Report the last value.
30+
yield last, False

0 commit comments

Comments
 (0)