# Python program for OEIS A291365
# Michael S. Branicky, Sep 27 2021

# A291365: Number of closed Sturmian words of length n.
# Known terms a(1)..a(16) of A291365, compared against a(n) further below.
data = [
    2, 2, 4, 6, 12, 16, 26, 36,
    52, 64, 86, 108, 142, 170, 206, 242,
]

from numba import njit

# (Python)
@njit
def hamwt(w, n):
    """Return how many of the n lowest-order bits of w are set.

    Bits at position n and above are not inspected.
    """
    total = 0
    probe = 1  # single-bit mask that walks from bit 0 up to bit n-1
    for _ in range(n):
        total += 1 if probe & w else 0
        probe = probe << 1
    return total

@njit
def is_strm(w, n):  # w is a finite Sturmian word
    """Return True if the n-bit word w passes the balance test.

    For every factor length from 2 to n-1, the number of 1s in the factors
    of that length may differ by at most one between any two factors.
    Length 1 is skipped because a single letter holds 0 or 1 ones.  Length n
    is skipped because there is only one factor of that length.  The function
    returns False as soon as some length shows a spread greater than one.
    """
    for length in range(2, n):
        window = (1 << length) - 1
        lo, hi = n, -1
        for shift in range(n - length + 1):
            # Extract the length-bit factor that starts at bit `shift`.
            factor = ((window << shift) & w) >> shift
            ones = hamwt(factor, length)
            lo = min(lo, ones)
            hi = max(hi, ones)
            if hi - lo > 1:
                return False
    return True

@njit
def is_clsd(w, n):  # is n-digit binary word closed
    """Return True if the n-bit binary word w is closed.

    The word is read from bit n-1 (first letter) down to bit 0 (last letter).
    A word is closed if its length is at most 1.  Otherwise it is closed if
    it has a border (a nonempty proper factor that is both a prefix and a
    suffix) that occurs in w only as that prefix and that suffix.  Those two
    occurrences may overlap.
    """
    # Words of length 0 and 1 are closed by definition.  A plain ``n == 1``
    # guard would let the empty word (n == 0) skip the loop and return False.
    if n <= 1:
        return True
    for l in range(1, n):
        mask = (1 << l) - 1
        prefix = (w & (mask << (n - l))) >> (n - l)  # top l bits
        suffix = w & mask                            # bottom l bits
        if prefix != suffix:
            continue
        # Search for the border at every shift strictly between the suffix
        # (shift 0) and the prefix (shift n-l).
        internal = False
        for j in range(1, n - l):
            if ((mask << j) & w) == (suffix << j):
                internal = True
                break
        if not internal:
            # This border occurs exactly twice, as prefix and as suffix.
            return True
    return False

@njit
def is_cs(w, n):   # n-digit word w is a closed Sturmian word
    """Return True if the n-bit word w is both closed and balanced (Sturmian).

    The closedness test runs first; the balance test runs only if it passes.
    """
    if not is_clsd(w, n):
        return False
    return is_strm(w, n)

@njit
def a(n):
    """Return the number of closed Sturmian binary words of length n.

    Only words whose leading (most significant) bit is 1 are enumerated,
    i.e. the integers 2**(n-1) through 2**n - 1.  Exchanging the two letters
    preserves both closedness and balance.  The words that start with 0
    therefore contribute the same count, so the result is twice the
    enumerated count.
    """
    first = 2**(n-1)
    stop = 2**n
    found = 0
    for word in range(first, stop):
        if is_cs(word, n):
            found += 1
    return found + found

# TEST THE FUNCTION AGAINST THE DATA
# Compute the first len(data) terms and print them above the reference terms.
# Then fail loudly on any disagreement; printing alone never fails.
check = [a(n) for n in range(1, len(data) + 1)]
print(check)
print(data)
print()
assert check == data, "a(n) disagrees with the reference data"

from time import time
time0 = time()

# Extend the sequence term by term.  After each term, report n, a(n), the
# length of the comma-separated term list, and the elapsed seconds.  The
# length is len(str(alst)) - 2, which drops the surrounding brackets;
# presumably it is tracked against OEIS's DATA-field length limit (confirm).
# NOTE(review): a(n) enumerates 2**(n-1) words, so in practice this loop
# stalls long before n = 100.  Under numba the bounds in a(n) are likely
# fixed-width integers, so n >= 64 would overflow; confirm before relying
# on large n.
alst = []
for n in range(1, 101):
    an = a(n)
    alst.append(an)
    print(n, an, len(str(alst))-2, time()-time0, flush=True)
    print("   ", alst, flush=True)
    print("   ", data, flush=True)