summaryrefslogtreecommitdiff
path: root/day14/day14.py
blob: ca35d2bc43d14b45f3da6b0f5a2c2ea6b4f714d5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Day 14."""
import re
from itertools import product
import sys

# Instruction grammars of the puzzle input:
#   "mask = <36 chars of X/0/1>"   -> group 1 is the mask string
#   "mem[<address>] = <value>"     -> groups are (address, value) as digits
mask_pattern = r"mask = (\w+)"
mem_pattern = r"mem\[(\d+)\] = (\d+)"
mask_prog, mem_prog = re.compile(mask_pattern), re.compile(mem_pattern)

# Puzzle input: one instruction per non-blank line. The encoding is explicit
# so the read does not depend on the platform's locale default. Blank lines
# (including a trailing one, or an entirely empty file) are dropped: they
# would otherwise reach mem_prog.match() and fail with an AttributeError.
# splitlines() also strips "\r" from CRLF-terminated input.
with open("input", encoding="utf-8") as f:
    data = [line for line in f.read().splitlines() if line.strip()]

def to_bin_str(int_str, width=36):
    """Return *int_str* (a decimal string or int) as zero-padded binary.

    The result is exactly *width* characters long (36 by default, the length
    of the puzzle's bitmasks), most significant bit first. Callers rely on
    that fixed width to zip it position-by-position against a mask.

    Raises:
        ValueError: if the value is negative or needs more than *width*
            bits. Neither case can be aligned with a fixed-width mask
            (``bin`` of a negative number keeps a stray ``'b'``, and a wider
            value would pair the mask with its high bits instead).
    """
    number = int(int_str)
    if number < 0 or number.bit_length() > width:
        raise ValueError(f"{int_str!r} does not fit in {width} unsigned bits")
    return format(number, f"0{width}b")

def mask_bit_part_1(mask, val):
    """Combine one part-1 mask character with one value bit.

    An 'X' in the mask passes the value bit through unchanged; any other
    mask character replaces the value bit.
    """
    return val if mask == 'X' else mask

def mask_bit_part_2(mask, val):
    """Combine one part-2 mask character with one address bit.

    A '0' in the mask keeps the address bit. Any other mask character
    ('1', or the floating 'X') is returned in place of the address bit.
    """
    keep_address_bit = mask == '0'
    return val if keep_address_bit else mask

# Part 1: the mask rewrites *values*. Each value is widened to a 36-bit
# string and combined bit-by-bit with the current mask via mask_bit_part_1;
# the result is stored at the (unmodified) address.
# Addresses are keyed as ints so "mem[08]" and "mem[8]" hit the same cell,
# consistent with part 2.
memory = {}
mask = None  # set by the first "mask = ..." line

for line in data:
    if line.startswith("mask"):
        mask = mask_prog.match(line).group(1)
    else:
        addr, val = mem_prog.match(line).groups()
        if mask is None:
            raise ValueError(f"memory write before any mask line: {line!r}")
        val = to_bin_str(val)
        res = "".join(mask_bit_part_1(m, v) for m, v in zip(mask, val))
        memory[int(addr)] = int(res, 2)

solution_1 = sum(memory.values())
print(solution_1)
# Regression check against the known answer for this input.
assert solution_1 == 5055782549997


# Part 2: the mask rewrites *addresses*. Per bit (see mask_bit_part_2):
# '0' keeps the address bit, '1' forces it to 1, and 'X' is "floating" and
# takes both values. An address with k floating bits therefore expands to
# 2**k concrete addresses, each of which receives the value.

memory = {}
mask = None  # reset so a stale part-1 mask is never reused silently

for line in data:
    if line.startswith("mask"):
        mask = mask_prog.match(line).group(1)
    else:
        orig_addr, val = mem_prog.match(line).groups()
        if mask is None:
            raise ValueError(f"memory write before any mask line: {line!r}")
        value = int(val)  # same for every expanded address
        masked = [
            mask_bit_part_2(m, a) for m, a in zip(mask, to_bin_str(orig_addr))
        ]
        # Locate the floating bits once per write. Every combination then
        # overwrites all of them, so reusing `masked` in place is safe.
        floating = [pos for pos, bit in enumerate(masked) if bit == "X"]
        for bits in product("01", repeat=len(floating)):
            for pos, bit in zip(floating, bits):
                masked[pos] = bit
            memory[int("".join(masked), 2)] = value

solution_2 = sum(memory.values())

# Part 2
print(solution_2)
# Regression check against the known answer for this input.
assert solution_2 == 4795970362286