/* Kevin Ryde, June 2024

   Usage: ./a.out [-v] [n] [n-n]

   This is some C code using GNU MP, and best in a 64 bit build,
   to calculate terms of

       A373727(n) = largest sum of digits of a cube
         which is n digits long

   The method is a brute force search through all cubes in decimal
   digits stepped by differences and looking at sum of digits.

   This is feasible through to n=36 in about 3 hours on a 3GHz CPU.
   The number of cubes grows as roughly (10^(1/3))^n so each n
   takes roughly 2.26x longer.


   Command Line
   ------------

   The calculation is made for individual n or ranges of n on the
   command line.  Output is a free-form report

       ./a.out 6
       =>
       n=6 max sum 44

   Optional command line parameter "-v" prints more information,
   or -v -v even more.


   Compile Time
   ------------

   Cubes are calculated in WIDTH many words of type unsigned long.
   Each word holds 12 digits in a 64 bit build, or 6 in 32 bits.
   Set WIDTH according to a desired n or maximum n.

   WIDTH is a compile time constant since that allows the compiler
   to unroll loops over WIDTH many words.

   The supplied WIDTH = 3 allows n <= 36 digits in a 64 bit build.
   Decrease it for more speed on smaller n.
   Increase it for a 32 bit build or to go beyond n=36.


   Implementation Notes
   --------------------

   See "Fivedec Decimal Digits" for notes on how digits are kept
   as 5 bits each in words, and how to add those words.

   Cubes are stepped by adding differences

       cube += d1
       d1 += d2
       d2 += 6

   These are

       d1(x) = (x+1)^3 - x^3    = 3*x^2 + 3*x + 1
       d2(x) = d1(x+1) - d1(x)  = 6*x + 6
               d2(x+1) - d2(x)  = 6

   "cube" starts as the largest x^3 < n digits long and the first
   step goes to the first cube of n digits.  Sum of digits is
   calculated during the step operation, hence taking one step to
   the first desired power.

   "d1 += d2" is after d1 has been used on the cube.  That allows
   the compiler some instruction level parallelism stepping d1 at
   the same time as sum of digits of the cube (and other work).

   GNU MP is used to calculate the initial x and the consequent
   cube, d1, d2 and num_cubes, but is not used within the search.

   Counting down num_cubes is easier than watching for "cube"
   exceeding n digits.  num_cubes is uint64_t in all builds since
   a count bigger than 32 bits is well within reach.
*/


#define WIDTH  3   /* how many WORDs are used to store each cube */


/* Set WANT_ASSERT to 1 to enable self-checks (slow).
   Set WANT_DEBUG to 1 for some rough development prints.
*/
#define WANT_ASSERT    1
#define WANT_DEBUG     0

/* NDEBUG is what makes assert() expand to nothing, and it must be
   defined before <assert.h> is included to have that effect. */
#if ! WANT_ASSERT
#define NDEBUG
#endif

#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdnoreturn.h>
#include <time.h>

#include <gmp.h>


/* ------------------------------------------------------------------------ */
/* Generic */

/* DEBUG(expr) runs the statement(s) expr when WANT_DEBUG is non-zero,
   and otherwise expands to an empty statement so expr is not even
   compiled. */
#if WANT_DEBUG
#define DEBUG(expr)  do { expr; } while (0)
#else
#define DEBUG(expr)  do {       } while (0)
#endif

/* LIKELY(cond) and UNLIKELY(cond) are branch prediction hints for
   GCC-compatible compilers, and plain cond otherwise.
   ATTRIBUTE_PRINTF marks a function as taking a printf-style format
   in its first parameter with the values starting at the second, so
   the compiler can check calls. */
#ifdef __GNUC__
#define LIKELY(cond)    __builtin_expect((cond) != 0, 1)
#define UNLIKELY(cond)  __builtin_expect((cond) != 0, 0)
#define ATTRIBUTE_PRINTF __attribute__ ((format (printf,1,2)))
#else
#define LIKELY(cond)    (cond)
#define UNLIKELY(cond)  (cond)
#define ATTRIBUTE_PRINTF
#endif

/* Number of elements in an array (a real array, not a pointer). */
#define numberof(array)  (sizeof (array) / sizeof (*(array)))

/* The smaller of x and y.  The arguments can be evaluated twice. */
#define MIN(x,y)  (((x) <= (y)) ? (x) : (y))

/* Evaluate to ceil(n/d) for n >= 0 and d > 0, with no risk of
   overflow since n+d-1 is never formed. */
#define DIV_CEIL(n,d)  ((n) / (d) + ((n) % (d) ? 1 : 0))

/* Print a printf-style message to stderr and terminate the program
   with exit status 1.  This function does not return. */
static noreturn ATTRIBUTE_PRINTF void
error (const char *format, ...)
{
  va_list args;

  va_start (args, format);
  (void) vfprintf (stderr, format, args);
  va_end (args);

  exit (1);
}

/* Return the CPU time consumed by this process so far, in seconds.
   If the process CPU-time clock cannot be read then error() out. */
static double
cputime (void)
{
  struct timespec now;
  const int failed = (clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &now) != 0);
  if (failed)
    error("Cannot clock_gettime() CPUTIME: %s\n",
          strerror(errno));

  const double whole_seconds = now.tv_sec;
  const double fraction      = now.tv_nsec/1e9;
  return whole_seconds + fraction;
}

/* Return the value from z as a uint64_t,
   or error out if too big. */
static uint64_t
z_get_uint64 (const mpz_t z)
{
  if (mpz_sizeinbase(z,2) > 64)
    error("z_get_uint64() z too big\n");
  if (sizeof(unsigned long) >= sizeof(uint64_t)) {
    return (mpz_get_ui(z));
  } else {
    /* mpz_export() is possible too, if rely on normal native
       endian-ness for uint64_t in memory.  But this conversion
       is not time critical.  */
    uint64_t ret = 0;
    int i;
    for (i=0; i < 64; i++)
      ret += (uint64_t)mpz_tstbit(z,i) << i;
    return (ret);
  }
}

/* Return the number of cubes of n digits.
   This is A181354 for n>=2, but at n=1 here 0 is taken as no digits
   but A181354 reckons 0 as one digit. */
uint64_t
num_cubes_n_digits (int n)
{
  if (! (n>=1)) error("num_cubes_n_digits() is for n>=1\n");
  mpz_t hi, lo;
  mpz_inits (hi, lo, NULL);

  /* hi = floor(cube_root(10^n-1)),
     cube root of largest cube of n digits */
  mpz_ui_pow_ui (hi, 10,n);
  mpz_sub_ui (hi, hi,1);
  mpz_root (hi, hi,3);

  /* lo = floor(cube_root(10^(n-1)-1)),
     cube root of largest cube of n-1 digits */
  mpz_ui_pow_ui (lo, 10,n-1);
  mpz_sub_ui (lo, lo,1);
  mpz_root (lo, lo,3);

  mpz_sub (hi, hi,lo);
  uint64_t num_cubes = z_get_uint64 (hi);

  mpz_clears (hi, lo, NULL);
  return (num_cubes);
}

/* ------------------------------------------------------------------------ */
/* Fivedec Decimal Digits.

   Decimal digits are held in 5 bits each within a WORD of type
   unsigned long.  WORD is expected to be the native CPU word size.
   At least one bit is spare above the 5s, ready to receive a carry.
   For example, a 64 bit WORD

       high                       low
       carry  five five ... five five
       \---/  \---------------------/
       4 bits      12*5 = 60 bits

   When adding two decimal digits x+y, an extra offset x+22 turns
   x=0..9 into x=22..31 and the latter is ready to have any carry
   from x+y+22 propagate up to the next 5 bits by ordinary binary
   addition.

   After such an offset addition, each 5 bit digit is

       0..9     if wrap-around (a carry propagated up)
       22..31   if not

   The 22..31 must drop back down to 0..9.  This is indicated by
   the "16" bit (high bit of the 5).
   fivedec_adjust_and_clear_carry() clears that 16 bit (which
   subtracts 16) and shifts and scales a copy of it to become the
   further 6 to subtract, for a total of 22.

   In fivedec_vector functions, and the various cube[] etc which
   are such vectors, vec[0] is the least significant WORD.
*/

typedef unsigned long WORD;

#define WORD_BITS            ((int) (CHAR_BIT * sizeof(WORD)))
#define WORD_MAX             ((WORD) (~ (WORD) 0))
#define FIVEDEC_NUM_DIGITS   ((WORD_BITS - 1) / 5)
#define FIVEDEC_TOTAL_BITS   (5*FIVEDEC_NUM_DIGITS)
#define FIVEDEC_ABOVE_BITS   (WORD_BITS - FIVEDEC_TOTAL_BITS)

#define FIVEDEC_TOTAL_MASK   (WORD_MAX >> FIVEDEC_ABOVE_BITS)
#define FIVEDEC_ABOVE_MASK   (~ FIVEDEC_TOTAL_MASK)
#define FIVEDEC_ONES         (FIVEDEC_TOTAL_MASK / 31)
#define FIVEDEC_ADJUST_MASK  (16 * FIVEDEC_ONES)
#define FIVEDEC_ADD_OFFSET   (22 * FIVEDEC_ONES)

#define FIVEDEC_1_AT_POS(pos)  ((WORD)1 << (5*(pos)))
#define FIVEDEC_MASK_BELOW_POS(pos)  (FIVEDEC_1_AT_POS(pos) - 1)

/* Return 1 if x is a valid fivedec word, or 0 if not.
   Valid means every digit is in the range 0..9 and there are only
   0 bits above the FIVEDEC_TOTAL_BITS of digits (no carry left).
   When invalid, a message about the first problem found (lowest
   bad digit, else the bits above) is printed to stdout.
*/
int
fivedec_validate (WORD x)
{
  WORD rest = x;   /* x with the digits already checked shifted out */
  int pos;
  for (pos = 0; pos < FIVEDEC_NUM_DIGITS; pos++, rest >>= 5) {
    int digit = (int) (rest & 31);
    if (digit > 9) {
      printf("oops, i=%d digit %d out of range\n", pos, digit);
      return (0);
    }
  }

  /* all digits are shifted out, so rest is what was above them */
  if (rest != 0) {
    printf("oops, above bits %d\n", (int) rest);
    return (0);
  }
  return (1);
}

/* Return 1 if every word of x[0..len-1] passes fivedec_validate(),
   or 0 if not.  All words are checked, not stopping at the first
   bad one, so there is a message printed for each bad word. */
int
fivedec_vector_validate (const WORD x[], int len)
{
  assert (len >= 0);

  int all_good = 1;
  for (int k = 0; k < len; k++)
    all_good &= fivedec_validate (x[k]);
  return all_good;
}

/* Return the carry out held in x after an offset addition such as
   made in fivedec_addc(), being the bits above the
   FIVEDEC_TOTAL_BITS of digits.
   This is 0 or 1, and that is assert()ed.
   x doesn't have to be adjusted yet, since only the bits above the
   digits are examined.  */
static inline WORD
fivedec_carry (WORD x)
{
  WORD c = x >> FIVEDEC_TOTAL_BITS;
  assert (c==0 || c==1);
  return (c);
}

/* Adjust digits of x back into 0..9 range after an offset addition.
   Clear any carry in x too.

   Each 5-bit digit of x is expected to be either 0..9 (it wrapped
   around and carried up) or 22..31 (it did not).  The latter have
   their "16" bit set and must have 22 subtracted.  */
static inline WORD
fivedec_adjust_and_clear_carry (WORD x)
{
  /* the 16 bit of each digit which needs adjusting */
  WORD flags = x & FIVEDEC_ADJUST_MASK;
  /* clear those 16 bits, which subtracts 16 from each such digit,
     and clear the bits above the digits (the carry) */
  x &= FIVEDEC_TOTAL_MASK - FIVEDEC_ADJUST_MASK;
  /* flags>>3 is 2 at each flagged digit, so 3* that is the
     remaining 6 to subtract there */
  return (x - 3*(flags >> 3));
}

/* Add fivedec words with carry.
   Store the digits of x + y + carry_in at *dst and return the carry
   out, which is 0 or 1.  x and y must be valid fivedec words and
   carry_in must be 0 or 1. */
static inline WORD
fivedec_addc (WORD *dst, WORD x, WORD y, WORD carry_in)
{
  assert (fivedec_validate(x));
  assert (fivedec_validate(y));
  assert (carry_in==0 || carry_in==1);

  /* Ordinary binary sum, with 22 extra at every digit so that a
     decimal carry out of a digit becomes a binary carry into the
     next.  */
  const WORD biased = FIVEDEC_ADD_OFFSET + carry_in + x + y;

  *dst = fivedec_adjust_and_clear_carry (biased);
  return fivedec_carry (biased);
}

/* Set dst[0..len-1] = x[0..len-1] + y[0..len-1] + c for vectors of
   fivedec words, least significant word first.
   c is the carry in, either 0 or 1.  The return is the carry out of
   the most significant word (or c unchanged if len==0). */
static inline WORD
fivedec_vector_addc (WORD *dst, const WORD *x, const WORD *y, int len, WORD c)
{
  assert (fivedec_vector_validate(x,len));
  assert (fivedec_vector_validate(y,len));
  assert (len >= 0);

  WORD carry = c;
  int k = 0;
  while (k < len) {
    carry = fivedec_addc (dst + k, x[k], y[k], carry);
    k++;
  }
  return carry;
}

/* Add a single fivedec word "add" into the vector x[0..len-1],
   in place, starting at the least significant word x[0].
   Return the carry out of x[len-1], which is 0 or 1
   (or "add" unchanged if len==0).  */
static inline WORD
fivedec_vector_addc_single (WORD x[], int len, WORD add)
{
  assert (fivedec_vector_validate(x,len));
  assert (fivedec_validate(add));
  assert (len >= 0);

  int i = 0;
  while (i < len) {
    /* after the first word, "add" is the carry to ripple upwards */
    add = fivedec_addc (&x[i], x[i], add, 0);
    if (LIKELY (add==0))
      return (0);     /* usually no carry, so nothing more to do */
    i++;
  }
  return (add);
}

/* Return the digit at position pos in x[],
   where pos=0 is the least significant digit of x[0].
   "len" is the number of words available in x.
   Errors out if pos is beyond the digits of those len words.  */
static WORD
fivedec_vector_get_digit (const WORD x[], int len, int pos)
{
  assert (pos >= 0);
  if (pos >= len*FIVEDEC_NUM_DIGITS)
    error ("fivedec_vector_get_digit() beyond end of x[]\n");

  const WORD word  = x[pos / FIVEDEC_NUM_DIGITS];
  const int  shift = 5 * (pos % FIVEDEC_NUM_DIGITS);
  return ((word >> shift) & 31);
}

/* Print digit d to fp.  d <= 9 is printed plainly in decimal.
   Anything bigger is printed in brackets like [12], so a digit
   which is out of range shows up clearly. */
void
fivedec_fprint_digit (FILE *fp, int d)
{
  if (d > 9) {
    fprintf(fp,"[%d]", d);
    return;
  }
  fprintf(fp,"%d", d);
}
/* Print the digits of x[0..len-1] to fp, most significant first,
   followed by a newline.  Leading 0 digits are skipped, except the
   least significant digit is always printed so zero shows as "0". */
void
fivedec_vector_fprint (FILE *fp, const WORD x[], int len)
{
  /* position of the highest digit to print */
  int top = len*FIVEDEC_NUM_DIGITS - 1;
  while (top > 0 && fivedec_vector_get_digit(x,len,top) == 0)
    top--;

  int pos;
  for (pos = top; pos >= 0; pos--)
    fivedec_fprint_digit(fp, (int) fivedec_vector_get_digit(x,len,pos));

  fprintf(fp,"\n");
}
/* Print the digits of x[0..len-1] and a newline to stdout, in the
   same form as fivedec_vector_fprint(). */
void
fivedec_vector_print (const WORD x[], int len)
{
  fivedec_vector_fprint(stdout,x,len);
}

/* Copy len words from src[] to dst[].
   This is a memcpy(), so the two must not overlap. */
static inline void
fivedec_vector_copy (WORD *dst, const WORD *src, int len)
{
  const size_t nbytes = len * sizeof(WORD);
  memcpy (dst, src, nbytes);
}
/* Set len words at ptr[] to all 0 bits, which is all digits 0. */
static inline void
fivedec_vector_clear (WORD *ptr, int len)
{
  const size_t nbytes = len * sizeof(WORD);
  memset (ptr, 0, nbytes);
}

/* Return the number of digits in x[0..len-1], not counting any
   high 0 digits.  The return is 0 if all digits are 0. */
int
fivedec_vector_digits_length (const WORD *x, int len)
{
  int count = len*FIVEDEC_NUM_DIGITS;
  /* drop high digits for as long as the topmost remaining is 0 */
  while (count > 0 && fivedec_vector_get_digit(x,len, count-1) == 0)
    count--;
  return (count);
}

/* Number of even position digits and odd position digits in a WORD,
   where position 0 is the least significant digit.  Shown high to
   low,

   even num digits    o e o e
   odd  num digits  e o e o e
*/
#define FIVEDEC_EVENS_NUM_DIGITS  ((FIVEDEC_NUM_DIGITS+1) / 2)
#define FIVEDEC_ODDS_NUM_DIGITS   (FIVEDEC_NUM_DIGITS - FIVEDEC_EVENS_NUM_DIGITS)

/* 1 at each odd position digit.
   The division gives a 1 at every second digit position starting
   from position 0, and the shift moves those up to the odd
   positions. */
#define FIVEDEC_ODDS_ONES                               \
  ( (FIVEDEC_MASK_BELOW_POS(2*FIVEDEC_ODDS_NUM_DIGITS)  \
     / FIVEDEC_MASK_BELOW_POS(2)) << 5)

/* 1 at each even position digit.
   The odds are shifted up to the even positions 2,4,etc and a 1 put
   at position 0.  The mask drops any 1 shifted above the digits. */
#define FIVEDEC_EVENS_ONES \
  (((FIVEDEC_ODDS_ONES << 5) | 1) & FIVEDEC_TOTAL_MASK)
/* mask 31 at each even position */
#define FIVEDEC_EVENS_MASK  (31 * FIVEDEC_EVENS_ONES)

/* Return the sum of the decimal digits held in fivedec word x.
   x must be a valid fivedec word per fivedec_validate(). */
static inline WORD
fivedec_sumdigits (WORD x)
{
  /* Doubling approach where first a shift adds each odd position
     digit to the even position digit below it.
     Then add each of the pairs at 2 mod 4 position to the 0 mod 4
     below, and so on.

     Two 5 bit places is maximum value 1023 for a sum, which
     allows WORD to have 113 digits all 9s, which is bigger than
     12 in 64 bit WORD, or 25 in 128 bit WORD.
  */
  assert (fivedec_validate(x));

  /* Both sides here are constants, so presumably the compiler
     drops this check when it cannot fail. */
  if (9*FIVEDEC_NUM_DIGITS >= FIVEDEC_MASK_BELOW_POS(2))
    error("oops, sumdigits might exceed 10 bits\n");

  /* a sum of two digits fits in one 5 bit place */
  assert (2*9 <= 31);
  x += (x >> 5);
  /* keep just the pair sums, which are at the even positions */
  x &= FIVEDEC_EVENS_MASK;
  /* each further step folds twice as many digit places down into
     the low places, for as many steps as the WORD size needs */
  if (FIVEDEC_NUM_DIGITS > 2)
    x += (x >> (2*5));
  if (FIVEDEC_NUM_DIGITS > 4)
    x += (x >> (4*5));
  if (FIVEDEC_NUM_DIGITS > 8)
    x += (x >> (WORD_BITS>40 ? 8*5 : 0));  /* quieten compiler */
  if (FIVEDEC_NUM_DIGITS > 16)
    x += (x >> (WORD_BITS>80 ? 16*5 : 0));  /* quieten compiler */
  if (FIVEDEC_NUM_DIGITS > 32)
    error("oops, sumdigits doesn't expect FIVEDEC_NUM_DIGITS > 32\n");
  /* the total is in the low two 5 bit places, discard the rest */
  return (x & FIVEDEC_MASK_BELOW_POS(2));
}

/* Set dst[0..len-1] to the decimal digits of src, least significant
   word first.
   The value in src is changed (it is divided down, ending as 0 on
   success).
   If src is too big to fit in len words then error() out. */
void
fivedec_vector_from_mpz (WORD *dst, int len, mpz_t src)
{
  assert (len >= 0);

  int w;
  for (w = 0; w < len; w++) {
    WORD packed = 0;
    int shift;
    for (shift = 0; shift < FIVEDEC_TOTAL_BITS; shift += 5) {
      /* quotient goes back to src, the remainder is the next digit */
      WORD digit = (WORD) mpz_fdiv_q_ui (src, src,10);
      packed += digit << shift;
    }
    dst[w] = packed;
  }

  assert (fivedec_vector_validate(dst,len));
  if (mpz_sgn(src) != 0)
    error("fivedec_vector_from_mpz() src doesn't fit in %d WORDs\n", len);
}


/* ------------------------------------------------------------------------ */

/* Verbosity level, incremented by each "-v" on the command line.
   Non-zero prints extra information and >= 2 prints even more. */
static int option_verbose = 0;

void
check_A373727_max_sum (int n, int sum)
{
  static const int want[] = {
    0,  /* no n=0 */
    8, 10, 18, 28, 28, 44, 46, 54, 63, 73,
    80, 82, 98, 100, 109, 118, 125, 136, 144, 154,
    154, 163, 172, 181, 190, 190, 199, 208, 217, 226,   /* n=30 */
    235,   /* n=31 */
    243,   /* n=32 */
    253,   /* n=33 */
    260,   /* n=34 */
    262,   /* n=35 */
    278    /* n=36 */
  };
  if (! (n >= 1)) error("check_A373727_max_sum() is for n>=1\n");
  if (n >= numberof(want)) {
    if (option_verbose) printf("  (beyond A373727 data)\n");
    return;
  }
  if (sum == want[n]) {
    if (option_verbose) printf("  (good vs A373727 data)\n");
    return;
  }
  printf("oops n=%d max sum\n", n);
  printf("  want %d\n", want[n]);
  printf("  got  %d\n", sum);
  exit(1);
}

/* ------------------------------------------------------------------------ */

/* Number of decimal digits which fit in WIDTH many WORDs. */
#define WIDTH_DIGITS   (WIDTH * FIVEDEC_NUM_DIGITS)

/* Arbitrary size 100 for holding the cubes which attain max_sum.
   These are only option_verbose extra info.
   max_count is how many cubes attain the maximum.  It can exceed
   the size of max_at[], in which case only the first
   numberof(max_at) of them are recorded.  */
static WORD max_at[100][WIDTH];
static int max_count;

/* Self-check at the end of a search for n digits.
   "cube" is the last cube examined and "d1" is the difference to
   the next cube.  That next cube should be >= 10^n, ie. more than
   n digits, and if not then error() out.
   cube[] is modified, being stepped to that next cube. */
void
check_final_cube (int n, WORD *cube, WORD *d1)
{
  /* one more cube += d1 should make cube >= 10^n */
  WORD carry = fivedec_vector_addc (cube, cube,d1,WIDTH, 0);

  int reached;
  if (n == WIDTH_DIGITS)
    reached = (carry != 0);   /* digit n would be beyond cube[] */
  else
    reached = (fivedec_vector_get_digit(cube,WIDTH, n) != 0);

  if (reached)
    return;
  error ("oops, next cube after final should be >= 10^n\n");
}

/* Return A373727(n), the largest sum of decimal digits among the
   cubes which are exactly n digits long, by a brute force search
   through all such cubes.

   n must be in the range 1 <= n <= WIDTH_DIGITS, otherwise error()
   out.

   Side effects: max_count is set to the number of cubes attaining
   the maximum, and the first numberof(max_at) of them are stored in
   max_at[].  If option_verbose then the number of cubes is printed,
   and if option_verbose >= 2 then each new or equal high is printed
   as it is found.  */
int
A373727 (int n)
{
  if (! (n >= 1))
    error ("A373727 is for n>=1\n");
  if (n > WIDTH_DIGITS)
    error ("n=%d exceeds WIDTH_DIGITS = %d\n  increase #define WIDTH to %d and re-compile\n", n, WIDTH_DIGITS, DIV_CEIL(n,FIVEDEC_NUM_DIGITS));

  uint64_t num_cubes = num_cubes_n_digits(n);
  if (option_verbose)
    printf("n=%d num_cubes %"PRIu64"\n", n, num_cubes);

  WORD cube[WIDTH], d1[WIDTH], d2[WIDTH];
  {
    mpz_t x, t;
    mpz_inits (x, t, NULL);

    /* x = floor(cube_root(10^(n-1)-1))
       is cube root of the cube < 10^(n-1),
       so first cube+=d1 makes first cube >= 10^(n-1)
    */
    mpz_ui_pow_ui (x, 10,n-1);
    mpz_sub_ui (x, x,1);
    mpz_root (x, x,3);
    DEBUG(gmp_printf("n=%d initial x = %Zd\n", n, x));

    /* cube = x^3
       The mpz value is printed before conversion, since
       fivedec_vector_from_mpz() divides t down to 0. */
    mpz_pow_ui (t, x,3);
    DEBUG (gmp_printf("  cube %Zd = ", t));
    fivedec_vector_from_mpz (cube,WIDTH, t);
    DEBUG (fivedec_vector_print(cube,WIDTH));

    /* d1 = (x+1)*3*x + 1 */
    mpz_add_ui (t, x, 1);
    mpz_mul (t, t, x);
    mpz_mul_ui (t, t, 3);
    mpz_add_ui (t, t, 1);
    fivedec_vector_from_mpz (d1,WIDTH, t);

    /* d2 = 6*x + 6 */
    mpz_add_ui (t, x,1);
    mpz_mul_ui (t, t,6);
    fivedec_vector_from_mpz (d2,WIDTH, t);

    mpz_clears (x, t, NULL);
  }

  /* Start afresh, so nothing is left over from a previous n. */
  max_count = 0;
  int max_sum = 0;
  for ( ; num_cubes != 0; num_cubes--) {
    int sum = 0;
    WORD cc=0;  /* carry in cube += d1 */
    WORD c1=0;  /* carry in d1 += d2   */
    int i;
    /* Step word by word from least significant, so the sum of
       digits of each new cube word is taken as it is formed. */
    for (i=0; i < WIDTH; i++) {
      /* cube += d1 */
      WORD w;
      cc = fivedec_addc(&w, cube[i],d1[i], cc);
      cube[i] = w;
      sum += fivedec_sumdigits (w);

      /* d1 += d2 */
      c1 = fivedec_addc(&d1[i], d1[i],d2[i], c1);
    }
    assert (cc==0 && c1==0);

    /* d2 += 6 */
    c1 = fivedec_vector_addc_single (d2,WIDTH, 6);
    assert (c1 == 0);

    DEBUG (printf("n=%d sum %d from cube ", n, sum);
           fivedec_vector_print(cube,WIDTH));
    assert (fivedec_vector_digits_length(cube,WIDTH) == n);

    if (UNLIKELY(sum >= max_sum)) {
      if (option_verbose >= 2) {
        printf("n=%d %s high %d at cube ",
               n, sum==max_sum ? "equal" : "new", sum);
        fivedec_vector_print(cube,WIDTH);
      }
      if (sum > max_sum) {
        /* new high, forget the cubes recorded for the old high */
        max_count = 0;
        max_sum = sum;
      }
      /* record the cube if there's room, but count it in any case */
      if (max_count < numberof(max_at))
        fivedec_vector_copy (max_at[max_count], cube, WIDTH);
      max_count++;
    }
  }

  check_final_cube(n,cube,d1);
  return (max_sum);
}

void
show_configuration (void)
{
  static int done = 0;
  if (done) return;
  done = 1;
  if (option_verbose) {
#ifdef __GNUC__
    printf("Have __GNUC__\n");
#else
    printf("Not __GNUC__\n");
#endif
    printf("WORD %zu bytes, %d bits\n", sizeof(WORD), WORD_BITS);
    printf("fivedec WORD %d digits in 5 bits each, space above %d bits\n",
           FIVEDEC_NUM_DIGITS,
           WORD_BITS - FIVEDEC_TOTAL_BITS);
    printf("WIDTH = %d words is WIDTH_DIGITS = %d\n",
           WIDTH, WIDTH_DIGITS);
    printf("assert()s %s\n\n",
           WANT_ASSERT ? "enabled" : "disabled");
  } else {
    assert (printf("assert()s enabled\n"));
  }
}

/* Calculate A373727(n) and print a report of it to stdout.
   If option_verbose then also print the cubes attaining the
   maximum (as left in max_at[] and max_count by the calculation),
   the mean digit, and the CPU time taken.
   The result is checked by check_A373727_max_sum(), which exits
   the program if it is not as expected. */
void
show (int n)
{
  show_configuration ();

  const double start_time = cputime();
  const int sum = A373727(n);
  const double elapsed_time = cputime() - start_time;

  printf("n=%d max sum %d\n", n, sum);

  if (! option_verbose) {
    check_A373727_max_sum (n, sum);
    return;
  }

  printf("  at %d cube(s):\n", max_count);
  int i = 0;
  while (i < MIN(max_count,numberof(max_at))) {
    assert(numberof(max_at[i]) == WIDTH);
    assert(fivedec_vector_digits_length(max_at[i],WIDTH) == n);
    printf("    ");
    fivedec_vector_print(max_at[i],WIDTH);
    i++;
  }
  if (max_count > numberof(max_at))
    printf("    ... (and more not recorded)\n");

  printf("  mean digit %.2lf, short of all nines by %d\n",
         (double)sum / n, 9*n-sum);

  /* rate is shown as 0 if the time is too short to measure */
  double per_second = 0;
  if (elapsed_time != 0)
    per_second = num_cubes_n_digits(n) / elapsed_time;
  printf("  this n CPU time %.1lf seconds, %.1lf million cubes/second\n",
         elapsed_time, per_second / 1e6);

  check_A373727_max_sum (n, sum);
  printf("\n");
}

/* Command line is any mixture of
       -v       increase verbosity
       n        show A373727(n)
       n-nhi    show A373727 for each of n to nhi inclusive
   processed in the order given.  Anything else is an error.
   If no n or range is given then show n=1 to 10. */
int
main (int argc, char *argv[])
{
  int seen_n = 0;
  setbuf(stdout,NULL);
  DEBUG (option_verbose = 2);

  for (int argi = 1; argi < argc; argi++) {
    const char *arg = argv[argi];
    const size_t arglen = strlen(arg);
    int n, nhi, endpos;

    if (strcmp(arg,"-v")==0) {
      option_verbose++;
      continue;
    }

    /* %n gives how much was parsed, which must be the whole arg */
    if (sscanf(arg, "%d%n", &n, &endpos) == 1
        && endpos == arglen) {
      seen_n = 1;
      show(n);
      continue;
    }

    if (sscanf(arg, "%d-%d%n", &n, &nhi, &endpos) == 2
        && endpos == arglen) {
      seen_n = 1;
      while (n <= nhi) {
        show (n);
        n++;
      }
      continue;
    }

    error("Unrecognised command line option: %s\n", arg);
  }

  if (! seen_n) {
    int n;
    for (n = 1; n <= 10; n++)
      show (n);
  }
  return(0);
}