from math import ceil
from numpy import searchsorted

# All unordered pairs of numbers (mod 2), divided into sets based on the value
# of a^3 + b^3 (mod 2).
FIRST_BIT = [[(0, 0), (1, 1)], [(0, 1)]]


def taxicab(N):
    """
    Returns all taxicab numbers up to and including 'N', as a sorted list.

    A taxicab number here is a value a^3 + b^3 reached by more than one
    unordered pair (a, b). Returns an empty list when N < 1.

    The basic algorithm is to enumerate over all pairs of numbers
    a, b < cubic_root(N), keep a^3 + b^3 in a hash map, and return all values
    that correspond to more than one pair. This has time and space complexity
    of O(N^(2/3)).

    We cannot do any better in terms of time - we have to go over all pairs of
    numbers up to N^(1/3). But we can for space - here we construct the pairs
    of numbers (a, b) "bit-by-bit". At each stage the numbers are defined
    (mod 2^bit) and the pairs are divided into sets based on the value of
    a^3 + b^3 (mod 2^bit). A taxicab number can only be defined by two pairs
    in the same set, which allows us to find taxicab numbers set-by-set (in a
    DFS-like fashion) without having to keep all sets in memory. Once we have
    constructed the numbers up to the desired bit, we find the taxicab numbers
    in the usual way.

    Each set is, on average, twice the size of the one it was constructed
    from - there are 4 times as many pairs, but they are divided into 2 sets.
    We get space complexity of O(N^(1/3)).

    Best values for N are 2^(3k) - other values are wasteful in terms of
    complexity.
    """
    if N < 1:
        # No cube sum of a positive pair is below 1; this also keeps a
        # negative value away from n_bits() and the cube root below.
        return []
    # A pair summing to at most N with b >= 1 has a^3 < N, so a <= c - 1.
    up_to_bit = n_bits(_ceil_cube_root(N) - 1)
    result = []
    for pair_set in FIRST_BIT:
        complete_from_set(pair_set, 2, up_to_bit, result)
    result.sort()
    # The search also yields some taxicab numbers above N (with gaps), so keep
    # only those <= N.
    return result[:searchsorted(result, N, side='right')]


def _ceil_cube_root(N):
    """
    Returns the smallest integer c >= 0 with c^3 >= N (expects N >= 0).

    The float estimate N ** (1./3) may be off by one (e.g. for exact cubes or
    large N), so it is corrected with exact comparisons.
    """
    c = int(ceil(N ** (1. / 3)))
    while c > 0 and (c - 1) ** 3 >= N:
        c -= 1
    while c ** 3 < N:
        c += 1
    return c


def complete_from_set(pair_set, bit, up_to_bit, result):
    """
    Adds to 'result' all taxicab numbers that can be generated from
    'pair_set'.

    'pair_set' holds pairs of (bit - 1)-bit numbers sharing the same value of
    a^3 + b^3 (mod 2^(bit - 1)). One more bit is added per recursion level;
    once 'bit' exceeds 'up_to_bit' the colliding cube sums of the set are
    appended to 'result'.
    """
    if bit > up_to_bit:
        result.extend(all_taxicab_from_set(pair_set))
        return
    for next_bit_set in add_bit_to_set(pair_set, bit):
        complete_from_set(next_bit_set, bit + 1, up_to_bit, result)


def add_bit_to_set(pair_set, n):
    """
    Gets a set of pairs of (n-1)-bit numbers (a, b) and returns all possible
    pairs of n-bit numbers with (a, b) as the n-1 lower bits.

    The pairs are divided into sets based on the value of a^3 + b^3
    (mod 2^n); an iterable of those sets (lists of pairs) is returned.
    """
    add = 1 << (n - 1)
    mask = (1 << n) - 1
    cube_dict = {}
    for pair in pair_set:
        for next_pair in four_next_pairs(pair, add):
            key = cube_sum(*next_pair) & mask
            # Append in place; rebuilding the list per insert is quadratic.
            cube_dict.setdefault(key, []).append(next_pair)
    return cube_dict.values()


def four_next_pairs(pair, add):
    """
    Returns the pairs made of 'add' added to x, y, neither, or both.

    For x <= y < add (true for pairs grown from FIRST_BIT), each returned pair
    is ordered (first <= second) and the list is in lexicographic order.
    When x == y, adding to only x or only y gives the same unordered pair, so
    it is emitted once and only three pairs are returned.
    """
    x, y = pair
    result = [(x, y), (x, y + add)]
    # This ensures we don't have repeating pairs.
    if x != y:
        result.append((y, x + add))
    result.append((x + add, y + add))
    return result


def all_taxicab_from_set(pair_set):
    """
    Returns (as a generator) all cube sums shared by two or more pairs in
    pair_set.
    """
    counts = {}
    for pair in pair_set:
        total = cube_sum(*pair)
        counts[total] = counts.get(total, 0) + 1
    return (total for total, count in counts.items() if count > 1)


def n_bits(n):
    """
    Returns the number of bits of n (0 for n == 0).

    Raises ValueError for negative n.
    """
    if n < 0:
        raise ValueError("n_bits() requires a non-negative integer")
    return int(n).bit_length()


def cube_sum(a, b):
    """Returns a^3 + b^3."""
    return a ** 3 + b ** 3