# Python program for OEIS A079874
# Michael S. Branicky, Feb 01 2023

# Computes the complete reachable set, THEN
# partitions that set by the number of lights on

# A079873: Number of positions that are exactly n moves from the starting
# position in the Classic Lights Out puzzle.
# NOTE: the n = 7 term is 467104, not 476104 (transposed digits).  Every
# reachable position lies at exactly one distance, so these terms must sum to
# the number of reachable positions, 2**23 = 8388608 (= sum(ondata)); 467104
# is also the value in the recorded run of the search at the end of this file.
rdata = [1, 25, 300, 2300, 12650, 53130, 176176, 467104, 982335, 1596279, 1935294, 1684446, 1004934, 383670, 82614, 7350]

# A079874: Number of configurations in the Classic Lights Out puzzle with n
# lights on.
ondata = [1, 5, 72, 576, 3166, 13326, 44256, 119944, 270531, 511319, 816752, 1113504, 1300852, 1300852, 1113504, 816752, 511319, 270531, 119944, 44256, 13326, 3166, 576, 72, 5, 1]

from time import time
time0 = time()

from collections import Counter

def moves(p, shape, states):
    """Return the set of positions one button press away from ``p``.

    ``p`` is a flat, row-major sequence of cell states for a board with
    ``shape = (rows, cols)``.  Pressing a cell advances that cell and each of
    its orthogonal neighbours that lies on the board by one state, wrapping
    around modulo ``states``.  Every cell of the board is pressed once, and the
    resulting positions are collected as tuples.
    """
    rows, cols = shape
    # The pressed cell itself plus its four orthogonal neighbours.
    offsets = ((0, 0), (1, 0), (0, 1), (-1, 0), (0, -1))
    successors = set()
    for row in range(rows):
        for col in range(cols):
            board = list(p)
            for d_row, d_col in offsets:
                rr, cc = row + d_row, col + d_col
                if rr < 0 or rr >= rows or cc < 0 or cc >= cols:
                    continue  # neighbour falls off the edge of the board
                idx = rr * cols + cc
                board[idx] = (board[idx] + 1) % states
            successors.add(tuple(board))
    return successors

def reachset(start, shape, states, v=False, maxd=float('inf')):
    """Breadth-first search of the positions reachable from ``start``.

    Args:
        start: initial position (hashable; a flat tuple of cell states).
        shape: ``(rows, cols)`` of the board, passed through to ``moves``.
        states: number of states per cell, passed through to ``moves``.
        v: when true, print the size of every non-empty search level
            (level 0 first) as ``"<size>, "`` on a single line.
        maxd: maximum number of moves to search; exhaustive by default.

    Returns:
        The set of positions reachable from ``start`` in at most ``maxd``
        moves (every reachable position when ``maxd`` is infinite).
    """
    seen, frontier, depth = set(), {start}, 0
    if v:
        print(len(frontier), end=", ")
    while frontier and depth < maxd:
        # Record the current level as seen *before* expanding it, so that a
        # self-loop or a move between two positions of the same depth is not
        # reported again as part of the next level.
        seen |= frontier
        nxt = {
            q
            for p in frontier
            for q in moves(p, shape, states)
            if q not in seen
        }
        if v and nxt:
            print(len(nxt), end=", ")
        frontier = nxt
        depth += 1
    # A depth-limited search stops with its last level still unexpanded; those
    # positions were reached (and counted above), so they belong in the result.
    # After an exhaustive search the frontier is empty and this is a no-op.
    seen |= frontier
    return seen

def lightson(q, shape):
    """Count the lit cells of position ``q``.

    ``q`` is a flat sequence of cell states and ``shape = (rows, cols)``; the
    first ``rows * cols`` entries are examined and every truthy (non-zero)
    one counts as a light that is on.
    """
    rows, cols = shape
    lit = 0
    for idx in range(rows * cols):
        if q[idx]:
            lit += 1
    return lit

# Classic puzzle: a 5x5 board whose lights have two states (off = 0, on = 1).
shape = (5, 5)
states = 2
start = (0,) * (shape[0] * shape[1])  # every light off

# Print the reference move-count terms, then the computed level sizes on the
# following line; the lone leading space stands in for the "[" of the list
# printed above so that the two rows are easy to compare by eye.
print(rdata)
print(" ", end="")
reach = reachset(start, shape, states, v=True)
print("\nElapsed", time() - time0)  # takes about 14-15 minutes in Google Colab

# Tally the reachable positions by how many lights are on and print the tally
# directly above the reference terms for comparison.
c = Counter(lightson(position, shape) for position in reach)
print([c[lit] for lit in range(1 + max(c))])
print(ondata)

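# OUTPUTS (recorded run).  Note: in this run the computed n = 7 term on the
# second line is 467104, while the reference list printed on the first line
# shows 476104; the computed value is the one consistent with the 2**23 total.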
"""
[1, 25, 300, 2300, 12650, 53130, 176176, 476104, 982335, 1596279, 1935294, 1684446, 1004934, 383670, 82614, 7350]
 1, 25, 300, 2300, 12650, 53130, 176176, 467104, 982335, 1596279, 1935294, 1684446, 1004934, 383670, 82614, 7350, Elapsed 857.926066160202
[1, 5, 72, 576, 3166, 13326, 44256, 119944, 270531, 511319, 816752, 1113504, 1300852, 1300852, 1113504, 816752, 511319, 270531, 119944, 44256, 13326, 3166, 576, 72, 5, 1]
[1, 5, 72, 576, 3166, 13326, 44256, 119944, 270531, 511319, 816752, 1113504, 1300852, 1300852, 1113504, 816752, 511319, 270531, 119944, 44256, 13326, 3166, 576, 72, 5, 1]
"""