from __future__ import print_function
import copy
from itertools import combinations, permutations, product
import math
from pprint import pprint
import sys

'''Program for finding terms of A253426 (http://oeis.org/A253426).

Copyright (c) 2015 Robert C. Lyons

To find a solution (if any) for the specified number of bits and symbols:

    find_solution(num_bits, num_symbols)
        Return a dictionary containing a solution if one exists; otherwise return {}.

    >>> import A253426
    >>> A253426.find_solution( 4, 5 ) == {'A': ['0000', '1111'], 'C': ['0010', '1101'], 'B': ['0001', '1110'], 'E': ['1000', '0111'], 'D': ['0100', '1011']}
    True

Please note that find_solution(num_bits, num_symbols) does an exhaustive search and can take 
a very long time to complete for num_bits >= 5. For example, find_solution( 5, 8 ) would run 
for a VERY long time. On the other hand, find_solution( 5, 9 ) returns relatively quickly
(i.e., within a few minutes), since there are no solutions and the method is able to 
eliminate a large number of the potential solutions.
'''

def get_encoding_transitions( num_encodings, num_bits ):
    """Map every encoding to the set of encodings it can transition to.

    The encodings are the integers 0 .. num_encodings-1, rendered as binary
    strings zero-padded to num_bits characters. Encoding i can transition
    to a different encoding j when every 1 bit of i is also set in j
    (i | j == j), i.e. j is obtained from i by turning on one or more bits.

    Returns a dictionary whose keys are the encoding strings and whose values
    are sets of encoding strings (empty when no transition is possible).
    """
    format_string = '{0:0' + str( num_bits ) + 'b}'
    # range() is used instead of xrange() so that this also runs on Python 3,
    # where xrange() no longer exists. Each encoding is formatted only once.
    encodings = [ format_string.format( i ) for i in range( num_encodings ) ]

    transitions = {}
    for i, encoding in enumerate( encodings ):
        transitions[ encoding ] = set(
            target
            for j, target in enumerate( encodings )
            if ( i != j and i | j == j )
        )
    return transitions

def get_possible_transition_encodings( transitions, encodings_for_encoding_scheme_1, num_symbols ):
    """Return the encodings reachable from the encodings of encoding scheme 1.

    The result is the union of the transition sets of all the given
    encodings, minus the given encodings themselves, as a list in no
    particular order.

    If any of the given encodings has fewer than num_symbols-1 transitions,
    an empty list is returned instead.
    """
    required_transitions = num_symbols - 1
    reachable = set()
    for first_encoding in encodings_for_encoding_scheme_1:
        targets = transitions[ first_encoding ]
        if ( len( targets ) < required_transitions ):
            # This encoding cannot reach enough other encodings; give up.
            return []
        reachable |= targets
    # Encodings already used by encoding scheme 1 are not candidates.
    return list( reachable - set( encodings_for_encoding_scheme_1 ) )

def get_encoding_scheme_1( encodings_for_encoding_scheme_1, symbols ):
    """Build encoding scheme 1 by pairing encodings with symbols positionally.

    The encoding at position i is assigned to symbols[i]. Returns a
    dictionary that maps each used symbol to a new one-element list holding
    its encoding.
    """
    return {
        symbols[ position ]: [ encoding ]
        for position, encoding in enumerate( encodings_for_encoding_scheme_1 )
    }

def get_encoding_scheme( encoding_scheme_1, symbol_seq_for_encoding_scheme_2, encodings_for_encoding_scheme_2, num_encodings ):
    """Add the encodings of encoding scheme 2 to encoding scheme 1.

    The encoding at position i of encodings_for_encoding_scheme_2 is appended
    to the encoding list of the symbol at position i of
    symbol_seq_for_encoding_scheme_2. Positions whose symbol is blank (" ")
    are skipped.

    Note: encoding_scheme_1 is modified in place and is also the object that
    is returned, so callers that need to keep the original must pass a copy.
    num_encodings is not used; it is kept so existing callers keep working.
    """
    first_symbol = convert_to_letter( 0 )
    encoding_scheme = encoding_scheme_1
    # enumerate() replaces the xrange() index loop, which fails on Python 3
    # because xrange() no longer exists there.
    for i, encoding in enumerate( encodings_for_encoding_scheme_2 ):
        symbol = symbol_seq_for_encoding_scheme_2[i]
        # Skip symbol if it's blank (" " compares less than the first symbol).
        if ( symbol >= first_symbol ):
            encoding_scheme[ symbol ].append( encoding )
    return encoding_scheme

def is_solution( encoding_scheme, transitions ):
    """Check whether an encoding scheme is a solution.

    encoding_scheme maps each symbol to its list of encodings; transitions
    maps each encoding to the set of encodings it can transition to.

    The scheme is accepted when, starting from the first encoding of every
    symbol, at least one encoding of every other symbol is in that first
    encoding's transition set. Returns True or False.
    """
    # Set versions of the encoding lists, for the intersection tests below.
    encoding_sets = {
        symbol: set( encodings )
        for symbol, encodings in encoding_scheme.items()
    }

    for source_symbol in encoding_scheme:
        reachable = transitions[ encoding_scheme[ source_symbol ][0] ]
        for target_symbol, target_encodings in encoding_sets.items():
            if ( target_symbol == source_symbol ):
                continue
            if ( not ( target_encodings & reachable ) ):
                # No encoding of target_symbol can be reached.
                return False
    return True

def get_symbols( num_symbols ):
    """Return the list of the first num_symbols symbols ('A', 'B', 'C', ...)."""
    # range() is used instead of xrange() so that this also runs on Python 3,
    # where xrange() no longer exists.
    return [ convert_to_letter( symbol_index ) for symbol_index in range( num_symbols ) ]

def convert_to_letter( symbol_index ):
    """Return the symbol for a zero-based index: 0 -> 'A', 1 -> 'B', and so on."""
    code_point = ord('A') + symbol_index
    return chr( code_point )

def find_solution( num_bits, num_symbols ):
    """Find a solution (if any) for the specified number of bits and number of symbols.

    The search is exhaustive. It tries every combination of num_symbols
    encodings as encoding scheme 1 (one encoding per symbol), then every
    combination of num_symbols encodings reachable from those as encoding
    scheme 2, then every assignment of the scheme 2 encodings to the symbols.
    The first candidate accepted by is_solution() is returned.

    Returns a dictionary that maps each symbol to its list of encodings, or
    {} if there is no solution. Progress is reported on stderr every
    1000000 scheme 1 combinations.

    >>> find_solution(1, 1) == {'A': ['0', '1']}
    True
    >>> find_solution(1, 2) == {}
    True
    >>> find_solution(2, 2) == {'A': ['00', '11'], 'B': ['01', '10']}
    True
    >>> find_solution(2, 3) == {}
    True
    >>> find_solution(3, 4) == {'A': ['000', '111'], 'B': ['001', '110'], 'C': ['010', '101'], 'D': ['100', '011']}
    True
    >>> find_solution(3, 5) == {}
    True
    """

    num_encodings = 2**num_bits

    symbols = get_symbols( num_symbols )
    # Symbols are assigned to the scheme 2 encodings last symbol first.
    symbols_reversed = symbols[::-1]

    transitions = get_encoding_transitions( num_encodings, num_bits )

    def by_one_bits( encoding ):
        # Sort key: number of 1 bits first, then the encoding itself.
        return ( encoding.count("1"), encoding )

    # As an optimization, we sort the encodings by the number of 1 bits.
    encodings = sorted( transitions, key=by_one_bits )

    for count, encodings_for_encoding_scheme_1 in enumerate( combinations(encodings, num_symbols), 1 ):
        encoding_scheme_1 = get_encoding_scheme_1( encodings_for_encoding_scheme_1, symbols )

        possible_encodings_for_encoding_scheme_2 = get_possible_transition_encodings( transitions, encodings_for_encoding_scheme_1, num_symbols )

        if ( len( possible_encodings_for_encoding_scheme_2 ) < num_symbols ):
            continue

        # As an optimization, we sort the possible encodings for encoding scheme 2 by the number of 1 bits and then reverse it.
        possible_encodings_for_encoding_scheme_2 = sorted( possible_encodings_for_encoding_scheme_2, key=by_one_bits, reverse=True )

        if ( count % 1000000 == 0 ):
            print("Iteration: "+ str(count) + ", encodings_for_encoding_scheme_1: " + str(encodings_for_encoding_scheme_1) + ", encoding_scheme_1: " + str(encoding_scheme_1) + ", possible_encodings_for_encoding_scheme_2: " + str(possible_encodings_for_encoding_scheme_2), file=sys.stderr )

        for encodings_for_encoding_scheme_2 in combinations(possible_encodings_for_encoding_scheme_2, num_symbols):
            # Each combination holds exactly num_symbols encodings and every
            # symbol has to receive one of them, so the valid assignments are
            # exactly the permutations of the symbols. Enumerating them
            # directly yields the same candidates, in the same order, as
            # filtering the product of (symbols + blank), but visits n!
            # sequences instead of (n+1)**n.
            for symbol_seq_for_encoding_scheme_2 in permutations( symbols_reversed ):
                # get_encoding_scheme() modifies its first argument, so it is
                # given a fresh copy of encoding scheme 1 for every candidate.
                encoding_scheme = get_encoding_scheme( copy.deepcopy( encoding_scheme_1 ), symbol_seq_for_encoding_scheme_2, encodings_for_encoding_scheme_2, num_encodings )
                if ( is_solution( encoding_scheme, transitions ) ):
                    return encoding_scheme

    return {}

if __name__ == "__main__":
    # Run the doctests embedded in this module before doing anything else.
    import doctest
    doctest.testmod()

    # Command line arguments: <num_bits> <num_symbols>
    num_bits = int( sys.argv[1] )
    num_symbols = int( sys.argv[2] )

    # Print the header first, since the search itself may take a long time.
    header = "Solution for " + str(num_bits) + " bits, " + str(num_symbols) + " symbols:"
    print( header )
    solution = find_solution( num_bits, num_symbols )
    pprint( solution )