// Markus Sigg 2023-04-06
//
// With BITWIDTH = 64 and MAX_GIGABYTES = 1024, a(0),...,a(38)
// can be calculated. Beware: This needs up to 965 GB memory.
// With a Xeon 5675, 96 GB RAM and a 900 GB swap file on an SSD,
// the program takes ~40 hours when compiled with gcc -Ofast.

#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

// Configuration. Both values may be overridden on the compiler command line,
// e.g. gcc -DBITWIDTH=128 -DMAX_GIGABYTES=1024.
//   BITWIDTH      : width in bits of the unsigned type Natural (32, 64 or 128).
//   MAX_GIGABYTES : memory limit in GiB; iterate() ends the program once the
//                   next iteration would need a larger buffer.
#ifndef BITWIDTH
#define BITWIDTH 64
#endif

#ifndef MAX_GIGABYTES
#define MAX_GIGABYTES 8
#endif

// Natural: unsigned integer type of exactly BITWIDTH bits used for numerators.
// The 128-bit variant relies on the GCC/Clang extension unsigned __int128.
#if   BITWIDTH == 32
typedef uint32_t Natural;
#elif BITWIDTH == 64
typedef uint64_t Natural;
#elif BITWIDTH == 128
typedef unsigned __int128 Natural;
#else
#error "BITWIDTH must be 32, 64 or 128"
#endif

// Prints a printf-style message to stdout, prefixed with the current local
// time in ctime() format followed by ": ". If the time cannot be formatted
// (ctime() returns NULL), the message is printed without the prefix.
// stdout is flushed after every call so output appears immediately.
// The format attribute lets GCC/Clang check the arguments against fmt.
void print(const char *fmt, ...) __attribute__((format(printf, 1, 2)));

void print(const char *fmt, ...)
{
	time_t t = time(NULL);
	const char *ct = ctime(&t);

	if (ct != NULL) {
		// ctime() ends its result with '\n'; print only what precedes it
		// instead of modifying the library's static buffer.
		printf("%.*s: ", (int) strcspn(ct, "\n"), ct);
	}

	va_list args;
	va_start(args, fmt);
	vfprintf(stdout, fmt, args);
	va_end(args);

	fflush(stdout);
}

const Natural naturalOne = 1;
const Natural naturalMax = (naturalOne << (BITWIDTH - 1)) - 1;

// A dyadic rational with value numerator / 2^denominator.
// Note: the field "denominator" holds the exponent k of the denominator 2^k,
// not the denominator itself. The struct is packed, so no padding follows the
// 1-byte exponent.
typedef struct __attribute__((__packed__)) {
	Natural numerator;
	uint8_t denominator; // exponent k; the actual denominator is 2^k
} Rational;

// Puts *r into lowest terms by cancelling factors of two: as long as the
// numerator is even and the exponent is positive, the numerator is halved
// and the exponent decremented. Afterwards the numerator is odd or the
// denominator is 2^0.
// A zero numerator is reported as an error and terminates the program.
void reduce(Rational *r)
{
	if (!r->numerator) {
		print("Error: Numerator = 0 in reduce().\n");
		exit(-1);
	}

	for (; r->denominator > 0 && (r->numerator & 1) == 0; r->denominator--) {
		r->numerator >>= 1;
	}
}

// Three-way comparison of the values of r1 and r2.
// Returns -1 if r1 < r2, 0 if they are equal and 1 if r1 > r2.
// Both numerators are brought to the common denominator 2^max(k1, k2) by
// shifting them left; if a scaled numerator would exceed naturalMax the
// program reports an overflow and exits.
int compareRationals(Rational r1, Rational r2)
{
	const unsigned bits = 8 * sizeof(Natural);
	unsigned common = r1.denominator < r2.denominator ? r1.denominator : r2.denominator;
	unsigned shift1 = r2.denominator - common; // n1 gets multiplied by 2^shift1
	unsigned shift2 = r1.denominator - common; // n2 gets multiplied by 2^shift2
	Natural n1 = r1.numerator;
	Natural n2 = r2.numerator;

	// n * 2^s stays <= naturalMax iff n == 0, or s < bits - 1 and
	// n <= naturalMax >> s. Testing s first avoids shifting by the full
	// width of Natural or more, which is undefined behavior.
	int overflow1 = n1 != 0 && (shift1 >= bits - 1 || n1 > (naturalMax >> shift1));
	int overflow2 = n2 != 0 && (shift2 >= bits - 1 || n2 > (naturalMax >> shift2));

	if (overflow1 || overflow2) {
		print("Error: Overflow in compareRationals().\n");
		exit(-1);
	}

	Natural e = n1 == 0 ? 0 : n1 << shift1;
	Natural f = n2 == 0 ? 0 : n2 << shift2;

	if      (e < f) return -1;
	else if (e > f) return  1;
	else            return  0;
}

// The linear expression a*x + b with dyadic rational coefficients.
typedef struct RationalPair {
	Rational a; // coefficient of x
	Rational b; // constant term
} RationalPair;

// Lexicographic three-way comparison of two pairs: by the coefficient a
// first, and by the constant term b only when the coefficients are equal.
// Returns -1, 0 or 1, as compareRationals() does.
int comparePairs(RationalPair p1, RationalPair p2)
{
	int byCoefficient = compareRationals(p1.a, p2.a);

	return byCoefficient != 0 ? byCoefficient : compareRationals(p1.b, p2.b);
}

// We operate with sorted arrays of RationalPairs. As the maps
// f:(a,b) -> (a/2,b/2) and g:(a,b) -> (3a,3b+1) preserve the order
// given by comparePairs(), the new array can be created by merging
// the array of elements f(a,b) with the array of elements g(a,b),
// dropping duplicates along the way.

// A growable array of RationalPairs; iterate() keeps it sorted by
// comparePairs() and free of duplicates.
typedef struct RationalPairsArray {
	size_t count;        // number of valid entries in pairs
	RationalPair *pairs; // heap buffer obtained from malloc()/realloc()
} RationalPairsArray;

// Replaces every pair (a, b) of the sorted, duplicate-free array by its images
//     f(a, b) = (a/2, b/2)       from (ax+b) / 2
//     g(a, b) = (3a, 3b+1)       from 3(ax+b) + 1
// Because f and g preserve the order given by comparePairs(), both image
// sequences are sorted, so a single merge that drops equal elements yields
// the new sorted, duplicate-free array.
//
// The buffer is doubled and the old values are copied to its upper half.
// The merged result is then written from index 0. The write index never
// passes the smallest unread index of the upper half, and each value is read
// before its slot can be written, so no unread value is overwritten.
//
// Exits with status 0 when the doubled buffer would exceed MAX_GIGABYTES, and
// with status -1 on allocation failure or arithmetic overflow.
void iterate(RationalPairsArray *array)
{
	const size_t count = array->count;

	if (count == 0) {
		return; // the images of an empty set are empty; nothing to allocate
	}

	// The limit is computed in unsigned long long because
	// MAX_GIGABYTES * 2^30 would wrap around in a 32-bit size_t. The first
	// test keeps the size computations below from wrapping around in size_t.
	const unsigned long long limit = (unsigned long long) MAX_GIGABYTES * 1024 * 1024 * 1024;

	if (count > SIZE_MAX / 2 / sizeof(RationalPair) ||
	    2ULL * count * sizeof(RationalPair) > limit) {
		print("Memory limit reached.\n");
		exit(0);
	}

	const size_t size = count * sizeof(RationalPair); // bytes holding the old values
	const size_t newSize = 2 * size;

	// Keep array->pairs untouched until realloc() is known to have succeeded.
	RationalPair *buffer = realloc(array->pairs, newSize);

	if (buffer == NULL) {
		print("Error: Failed to allocate %zu bytes.\n", newSize);
		exit(-1);
	}

	array->pairs = buffer;

	// Old values go to the upper half; the two halves do not overlap.
	RationalPair *old = buffer + count;
	memcpy(old, buffer, size);

	size_t i1 = 0;  // old value whose f-image (n/2) is pending
	size_t i2 = 0;  // old value whose g-image (3n+1) is pending
	size_t out = 0; // next write position, starting at the beginning of buffer

	while (i1 < count || i2 < count) {
		RationalPair p1 = {{ 0, 0 }, { 0, 0 }};
		RationalPair p2 = {{ 0, 0 }, { 0, 0 }};

		if (i1 < count) { // (ax+b) / 2 = (a/2)x + b/2
			Rational a = old[i1].a;
			Rational b = old[i1].b;

			p1 = (RationalPair) {
				{ a.numerator, a.denominator + 1 },
				{ b.numerator, b.denominator + 1 }
			};
		}

		if (i2 < count) { // 3(ax+b) + 1 = (3a)x + (3b+1)
			Rational a = old[i2].a;
			Rational b = old[i2].b;

			if (a.numerator > naturalMax / 3) {
				print("Error: Overflow 1 in iterate().\n");
				exit(-1);
			}

			// For b = n / 2^k, 3b + 1 = (3n + 2^k) / 2^k. If k >= width - 1,
			// then 2^k alone exceeds naturalMax: the subtraction below would
			// wrap around, or the shift would be undefined. Reject that first.
			if (b.denominator >= 8 * sizeof b.numerator - 1 ||
			    b.numerator > (naturalMax - (naturalOne << b.denominator)) / 3) {
				print("Error: Overflow 2 in iterate().\n");
				exit(-1);
			}

			p2 = (RationalPair) {
				{ 3 * a.numerator, a.denominator },
				{ 3 * b.numerator + (naturalOne << b.denominator), b.denominator }
			};

			reduce(&p2.b);
		}

		RationalPair next;

		if (i1 == count) {
			next = p2;
			i2++;
		} else if (i2 == count) {
			next = p1;
			i1++;
		} else {
			int cmp = comparePairs(p1, p2);

			if (cmp < 0) {
				next = p1;
				i1++;
			} else if (cmp > 0) {
				next = p2;
				i2++;
			} else { // equal images: keep a single copy
				next = p1;
				i1++;
				i2++;
			}
		}

		buffer[out++] = next;
	}

	array->count = out;
}

int main(int argc, char **argv)
{
	print("Will use up to %d GB memory.\n", MAX_GIGABYTES);
	print("Size of Natural     : %lu\n", (unsigned long) sizeof(Natural));
	print("Size of Rational    : %lu\n", (unsigned long) sizeof(Rational));
	print("Size of RationalPair: %lu\n", (unsigned long) sizeof(RationalPair));

	if (8 * sizeof(Natural) != BITWIDTH) {
		print("Error: Wrong size of Natural.\n");
		exit(-1);
	}

	RationalPairsArray array;

	// start with element x = (1/2^0)x + 0/2^0

	array.count = 1;
	array.pairs = malloc(sizeof(RationalPair));
	array.pairs[0] = (RationalPair) {{ 1, 0 }, { 0, 0 }};

	for (int i = 0; ; i++) {
		print("a(%2d) = %lu\n", i, (unsigned long) array.count);
		iterate(&array);
	}
}