/* Kevin Ryde, March 2024

   Usage: ./a.out [-v] [n] [n-n]...

   This is some C code using MPFR and GMP to calculate terms of

       A084420(n) = smallest k for which there are n powers of 10
         between k! and (k+1)!, exclusive.

   The method is to estimate k from how log10(k!) grows, then probe
   above or below that estimate for the actual result.  This usually
   needs only a few mpfr_lgamma() and logs in about 5*n or 6*n bits of
   precision.


   Command Line
   ------------

   One or more individual n or ranges of n are specified by

       ./a.out 123 1-50

   Output is b-file style "n a(n)",

       ./a.out 6-8
       =>
       6 100502
       7 1000617
       8 10006509

   Optional parameter -v prints more verbose notes on the calculation.


   Target k by Delta
   -----------------

   The aim is to find

       k = smallest k with istep(k) = n
       where istep(x) = length(x+1) - length(x)
       and   length(x) = decimal digit length of x!
                       = floor(log10(x!)) + 1

   For example at n=1, a(1) = k = 3 is k! = 6 and then (k+1)! = 24
   is istep(k) = 1 digit longer.

   A fractional step can be considered from logs,

       fstep(x) = log10((x+1)!) - log10(x!)
                = log10(x+1)

   A lower bound for the target k is

       k >= s = 10^(n-1)
       since fstep(x) <= n-1 for all x < s   (equality at x = s-1)
       and istep(x) <= ceil(fstep(x)),
       so cannot have istep(x) = n for any x < s

   For the target a(n) = k, each x in the range s <= x < k has
   istep(x) = n-1,

       x        =  s   s+1  s+2  ... k-1  k
       istep(x) = n-1  n-1  n-1  ... n-1  n

   Consequently lengths of s! and k! differ by

       length(k) - length(s) = (k-s)*(n-1)

   A delta measures how much length(x) exceeds the length it would
   have if every step from s onwards were exactly n-1,

      delta(x) = length(x) - (length(s) + (x-s)*(n-1))

   The aim is then to find

       a(n) = smallest k for which delta(k) = 0
                             and   delta(k+1) = 1

   The benefit of this form is that a given x can be tested for
   whether its delta(x) implies x <= target k, or x > target k,
   allowing probing, bisecting, or similar.


   Integer Spanning
   ----------------

   The way a(n) = k spans powers of 10 can be illustrated in logs,

       L       L+1                 L+n
       *--------*--------...--------*----
              |                       |
           log10(k!)               log10((k+1)!)
              | --------------------> |
                      log10(k+1)

       where L = floor(log10(k!))

   Since log10(x) grows slowly, expect a(n) = k to have its log10(k!)
   only a small amount below the integer L+1, and the step
   fstep(k) = log10(k+1) to take log10((k+1)!) to only a small amount
   above the integer L+n.


   k Estimate
   ----------

   The above spanning leads to an estimate for a(n) = k.
   The lower bound s has a certain fractional part f in its log10,

       s = 10^(n-1)
       f = frac(log10(s!)) = fractional part 0 <= f < 1

   Stepping from s to s+1 adds log10(s+1) which is

       log10(s+1) = n-1 + small amount

   This small amount pushes frac(log10((s+1)!)) to a little bigger
   than f from s!.  The idea is to estimate how many such small
   amounts will be needed to push the fractional part up to 1 and
   hence where the integer spanning shown above will happen.

   If j is small relative to s then log10(s+j) is roughly linear in j,

       log(1+x) = x/1 - x^2/2 + x^3/3 - ...   (Mercator)
       x = j/s and ignore terms x^2 and higher
       log10(s+j) = log10(s) + j/(s*log(10)) - ...
                  = n-1      + j/(s*log(10)) - ...

   A triangular sum of those is

          Sum  log10(s+j)   = i*(n-1) +  i*(i-1)/2 /(s*log(10)) - ...
       0 <= j < i

   i*(n-1) is an integer, so want the i*(i-1)/2 part to be

       1-f = i*(i-1)/2 /(s*log(10))
       i*(i-1) = s*(1-f)*log(100)

   i*(i-1) can be increased to i^2 since with s large compared to i
   doing so decreases the solution i by only about constant 1/2.  So

       i = sqrt( s*(1-f)*log(100) )
       k estimate = s + sqrt( s*(1-f)*log(100) )

   This estimate seems often quite close, either correct or 1 to 3
   away (in feasible size n).  But irrespective of how close,
   probing up or down then bisecting by the delta() condition can
   find the actual result.

   Estimate s + sqrt(s) is seen in the actual values.
   Terms like a(13) = 1000001981666 are power 10^(n-1) plus an
   amount near sqrt(10^(n-1)).


   Precision
   ---------

   An approximate bit length of factorial_length(x) can be found from

       x! = approx  sqrt(2*pi*x) * (x/e)^x      (Stirling)
       log2(log10(x!)) = approx  log2(x) + log2(log2(x))

   This is factorial_length_bitlength_estimate() in the code.
   factorial_length() requires at least this much precision so as
   to find the integer part of the length.

   A084420_mpz() first runs factorial_length() with additional
   precision of bitlength(s) in order to have that many bits in
   the "frac" part used by estimate_k().  Extra precision at this
   point is worthwhile since a good estimate k saves probing and
   bisecting steps.

   After this, A084420_mpz() only requires enough precision for the
   integer part of factorial_length().  But the quantities are
   close to integers and require additional precision to be sure
   which side of integer boundaries.

   How much additional precision seems to grow with n, but it's not
   clear quite how much.  (It'll depend on quite how close to
   integers the various log10(x!) fall.)  A084420_mpz() starts at
   1.5* the integer bit length.  factorial_length() increases on an
   as-needed basis until sure of the result.


   Things Not Done
   ---------------

   mpfr_lgamma() and mpfr_log_ui() are called on integer arguments
   (x+1 and 10), first rounding down and then rounding up.  Correct
   rounding means those two results differ by "1 ULP", so it would be
   nice to have an mpfr_add_one_ulp(), or a single call giving both the
   round down and round up results together.
*/

/* Build-time switches, 0 = off, 1 = on.
   WANT_ASSERT enables the assert() consistency checks.
   WANT_DEBUG enables the DEBUG() trace printing.  */
#define WANT_ASSERT  0
#define WANT_DEBUG   0

/* NDEBUG must be settled before <assert.h> is included below. */
#if WANT_ASSERT == 0
#define NDEBUG
#endif

#include <assert.h>
#include <math.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdnoreturn.h>
#include <string.h>
#include <unistd.h>

#include <gmp.h>
#include <mpfr.h>


/* ---------------------------------------------------------------------- */
/* Generic */

/* DEBUG(stmts) runs "stmts" only when WANT_DEBUG is enabled, otherwise
   expands to an empty statement so the code is compiled out. */
#if ! WANT_DEBUG
#define DEBUG(expr)  do { } while (0)
#else
#define DEBUG(expr)  do { expr; } while (0)
#endif

/* Element count of a true array (not valid on a pointer). */
#define numberof(array)  (sizeof (array) / sizeof *(array))

/* Let GCC-compatible compilers check printf-style format arguments,
   format string being parameter 1 and the variadic args from 2. */
#if defined (__GNUC__)
#define ATTRIBUTE_PRINTF __attribute__ ((format (printf, 1, 2)))
#else
#define ATTRIBUTE_PRINTF
#endif

/* Print a printf-style message to stderr and exit with status 1.
   Never returns. */
static noreturn ATTRIBUTE_PRINTF void
error (const char *format, ...)
{
  va_list args;

  va_start (args, format);
  vfprintf (stderr, format, args);
  va_end (args);

  exit (1);
}


/* ---------------------------------------------------------------------- */
/* sample data for consistency checks */

/* Known terms as decimal strings, indexed by n.
   Index 0 is only a placeholder, the sequence starts at n=1. */
static const char * const A084420_data[] = {
  [0]  = "0",
  [1]  = "3",
  [2]  = "14",
  [3]  = "103",
  [4]  = "1042",
  [5]  = "10158",
  [6]  = "100502",
  [7]  = "1000617",
  [8]  = "10006509",
  [9]  = "100019088",
  [10] = "1000004377",
  [11] = "10000170793",
  [12] = "100000442970",
  [13] = "1000001981666",
  [14] = "10000005339905",
};

/* Found a(n) = k, check it against A084420_data[].
   n values with no sample data are not checked.  This includes n < 1,
   where index 0 is only a placeholder.
   A mismatch prints the got/want values to stderr and exits with
   status 1.  A malformed table string is reported through error().  */
void
check_data (int n, const mpz_t k)
{
  /* Test the sign explicitly rather than relying on a negative int
     wrapping to a huge size_t in the comparison. */
  if (n < 1 || (size_t) n >= numberof(A084420_data))
    return;

  const char *want_str = A084420_data[n];
  mpz_t want;
  int conv = mpz_init_set_str (want, want_str, 10);
  if (conv != 0)
    error("oops, A084420_data[%d] bad string \"%s\"\n", n, want_str);
  if (mpz_cmp(k,want) != 0) {
    fprintf(stderr, "oops, n=%d\n", n);
    gmp_fprintf(stderr, "  got  %Zd\n", k);
    gmp_fprintf(stderr, "  want %Zd\n", want);
    exit(1);
  }
  mpz_clear(want);
}

/* ---------------------------------------------------------------------- */

/* Number of -v options seen on the command line, 0 = quiet. */
int option_verbose = 0;

/* Lower and upper bounds on 1/log(10).  set_constants() sets them at
   the MPFR default precision. */
mpfr_t log10_recip_lo;
mpfr_t log10_recip_hi;

void
set_constants (void)
{
  mpfr_set_prec (log10_recip_lo, mpfr_get_default_prec());
  mpfr_set_prec (log10_recip_hi, mpfr_get_default_prec());

  /* log(10) to be denominators, so round opposite way */
  mpfr_log_ui (log10_recip_lo, 10, MPFR_RNDU);
  mpfr_log_ui (log10_recip_hi, 10, MPFR_RNDD);

  /* 1/log(10) */
  mpfr_ui_div (log10_recip_lo, 1, log10_recip_lo, MPFR_RNDD);
  mpfr_ui_div (log10_recip_hi, 1, log10_recip_hi, MPFR_RNDU);

  assert (mpfr_cmp (log10_recip_lo, log10_recip_hi) < 0);
}

/* operating in type double so as to detect overflow
   (even if normal mpfr_prec_t size makes that unlikely) */
void
set_precision (double pp)
{
  double fp = floor(pp);
  if (fp > MPFR_PREC_MAX)
    error("exceeded MPFR_PREC_MAX\n");
  mpfr_prec_t p = fp;
  if (option_verbose)
    mpfr_printf("# set precision %Pd bits\n", p);
  mpfr_set_default_prec (p);
  set_constants();
}

/* Raise the MPFR default precision to 1.2 times its current value plus
   64 bits.  The fixed 64 also ensures progress from small precisions. */
void
increase_precision (void)
{
  double grown = mpfr_get_default_prec () * 1.2 + 64;
  set_precision (grown);
}

/* Return an estimated bit length of factorial_length(x), which is
   about log2(log10(x!)).  Computed as bitlength(x) plus the bit length
   of that bit length, i.e. roughly log2(x) + log2(log2(x)). */
double
factorial_length_bitlength_estimate (const mpz_t x)
{
  size_t xbits = mpz_sizeinbase (x, 2);

  mpz_t tmp;
  mpz_init_set_ui (tmp, xbits);
  size_t xbits_bits = mpz_sizeinbase (tmp, 2);
  mpz_clear (tmp);

  return (double) xbits + (double) xbits_bits;
}

/* factorial_length(len,x) sets "len" to the decimal digit length of
   x!, which is floor(log10(x!)) + 1.  x must be >= 0.
   If "frac" is not NULL then it is set to the fractional part of the
   round-up bound on log10(x!), at frac's own precision.  Its rounding
   is therefore unspecified.

   log10(x!) = lgamma(x+1) / log(10) is computed twice in the MPFR
   default precision, once with every step rounded down ("lo") and once
   with every step rounded up ("hi").  If lo and hi do not agree on the
   integer part, increase_precision() is called and the calculation is
   repeated.  */
void
factorial_length (mpz_t len, const mpz_t x, mpfr_t frac)
{
  assert (mpz_sgn (x) >= 0);
  assert (len != x);

  mpz_t hi_floor;
  mpz_init (hi_floor);

  int agreed = 0;
  while (! agreed) {
    mpfr_t lo, hi;
    int lgamma_sign;

    /* lo <= x+1 <= hi */
    mpfr_init_set_ui (lo, 1, MPFR_RNDD);
    mpfr_add_z (lo, lo,x, MPFR_RNDD);
    mpfr_init_set_ui (hi, 1, MPFR_RNDU);
    mpfr_add_z (hi, hi,x, MPFR_RNDU);

    /* log(x!) = lgamma(x+1), rounded each way */
    mpfr_lgamma (lo, &lgamma_sign, lo, MPFR_RNDD);
    mpfr_lgamma (hi, &lgamma_sign, hi, MPFR_RNDU);
    DEBUG(mpfr_printf ("lgamma %.10RDf\n", lo);
          mpfr_printf ("       %.10RUf\n", hi));

    /* log10(x!), using the matching bound on 1/log(10) */
    mpfr_mul (lo, lo, log10_recip_lo, MPFR_RNDD);
    mpfr_mul (hi, hi, log10_recip_hi, MPFR_RNDU);

    /* x==0 and x==1 give log10(x!) = 0 exactly, so lo==hi there */
    assert(mpz_cmp_ui(x,1)<=0
           ? mpfr_cmp (lo, hi) == 0
           : mpfr_cmp (lo, hi) < 0);

    mpfr_get_z (len,      lo, MPFR_RNDD);   /* floor of each bound */
    mpfr_get_z (hi_floor, hi, MPFR_RNDD);
    if (frac != NULL)
      mpfr_frac (frac, hi, MPFR_RNDD);

    mpfr_clear (lo);
    mpfr_clear (hi);

    DEBUG(gmp_printf("  len   = %Zd\n  len_hi = %Zd\n", len, hi_floor));
    agreed = (mpz_cmp (len, hi_floor) == 0);
    if (! agreed)
      increase_precision ();
  }

  mpz_clear (hi_floor);
  mpz_add_ui (len, len,1);   /* floor(log10(x!)) -> digit count */
}

/* Set "k" to an estimate of a(n),
   using s = 10^(n-1) and f = frac(log10(s!)).
   The value in f is clobbered.  */
void
estimate_k (mpz_t k, const mpz_t s, mpfr_t f)
{
  /* k = s + ceil( sqrt( s*(1-f)*log(100) ) ),
     computed at the precision of f with every step rounded up. */
  mpfr_t t;
  mpfr_init2 (t, mpfr_get_prec (f));

  mpfr_ui_sub (f, 1,f, MPFR_RNDU);      /* f becomes 1-f (clobbered) */
  mpfr_log_ui (t, 100, MPFR_RNDU);      /* log(100) */
  mpfr_mul (t, t,f, MPFR_RNDU);
  mpfr_mul_z (t, t,s, MPFR_RNDU);       /* s*(1-f)*log(100) */
  mpfr_sqrt (t, t, MPFR_RNDU);
  mpfr_get_z (k, t, MPFR_RNDU);
  mpfr_clear (t);

  /* never step below s */
  if (mpz_sgn (k) < 0)
    mpz_set_ui (k, 0);
  mpz_add (k, k,s);
}

/* Given s = 10^(n-1) and slen = factorial_length(s), return the sign
   of delta(k) = length(k) - (slen + (k-s)*(n-1)).
   That is 0 when delta(k) = 0 and 1 when delta(k) > 0.
   Requires k >= s, and for such k delta(k) is never negative.  */
int
delta_sign (int n, const mpz_t s, const mpz_t slen, const mpz_t k)
{
  assert (mpz_cmp(k,s) >= 0);

  mpz_t klen, delta, steps;
  mpz_init (klen);
  mpz_init (delta);
  mpz_init (steps);

  factorial_length (klen, k, NULL);

  /* delta = (klen - slen) - (k-s)*(n-1) */
  mpz_sub (steps, k, s);
  mpz_sub (delta, klen, slen);
  mpz_submul_ui (delta, steps, n-1);
  DEBUG(gmp_printf("  slen %Zd klen %Zd delta %Zd\n",
                   slen, klen, delta));

  int result = mpz_sgn (delta);

  mpz_clear (steps);
  mpz_clear (delta);
  mpz_clear (klen);

  assert (result == 0 || result == 1);   /* k>=s so never negative */
  return result;
}


/* probe_for_delta() sets k to a point where delta_sign(k) = target_sign
   by probing in steps away from "kmid".
   kmid must be the opposite sign: delta_sign(kmid) = 1 - target_sign.
   If target_sign = 0 then the probing is downwards, finding k < kmid.
   If target_sign = 1 then the probing is upwards, finding k > kmid.
   Probes are at kmid +/- 1, 2, 4, 8, etc. (doubling).  A downward
   probe which would go below s stops at k = s instead, since
   delta(s) = 0 always.
   k_other must initially have the same value as kmid, but must be a
   separate variable.  Each probe which sees the non-target sign sets
   k_other to that probe point.  So on return k_other is the non-target
   point closest to the "k" found, giving a narrower bracket.  */
void
probe_for_delta (mpz_t k, mpz_t k_other,
                 int n, const mpz_t s, const mpz_t slen,
                 const mpz_t kmid, int target_sign)
{
  assert (target_sign == 0 || target_sign == 1);
  assert (mpz_cmp(kmid,k_other) == 0);
  /* k and k_other are written while kmid is still being read */
  assert (k != kmid && k_other != kmid && k != k_other);

  mpz_t step;
  mpz_init_set_si (step, target_sign==0 ? -1 : 1);
  for (;;) {
    mpz_add (k, kmid,step);   /* try k = kmid + step */
    if (mpz_cmp(k,s) < 0) {
      /* only when probing down: clamp to s, where delta(s) = 0 */
      mpz_set (k,s);
      break;
    }
    if (option_verbose)
      gmp_printf("# probe %s try %Zd\n",
                 target_sign==0 ? "down" : "up", k);
    int sign = delta_sign (n, s,slen, k);
    if (sign == target_sign) break;
    mpz_set (k_other, k);          /* non-target, closer than before */
    mpz_mul_2exp (step, step,1);   /* double */
  }
  mpz_clear (step);
}

/* Set k to A084420(n), the smallest k with n powers of 10 strictly
   between k! and (k+1)!.  Requires n >= 1 (error() exits otherwise).

   1. s = 10^(n-1) and slen = factorial_length(s), computed with extra
      precision so that frac(log10(s!)) is good for estimate_k().
   2. Test the estimate with delta_sign() and probe away from it to
      bracket the target as delta(k) = 0, delta(khi) > 0.
   3. Bisect that bracket down to khi = k+1.  */
void
A084420_mpz (mpz_t k, int n)
{
  if (! (n>=1))
    error("A084420_mpz() is for n>=1 (got %d)\n", n);

  mpz_t s, slen, kmid, khi;
  mpz_init (s);
  mpz_init (slen);
  mpz_init (kmid);
  mpz_init (khi);

  /* s = 10^(n-1), lower bound for the target */
  mpz_ui_pow_ui (s, 10,n-1);
  if (option_verbose)
    printf("# s = 10^(n-1) bit length %zu\n", mpz_sizeinbase(s,2));

  double slen_bitlength_estimate
    = factorial_length_bitlength_estimate(s);
  if (option_verbose)
    printf("# length(s!) bit length estimate %.0f\n",
           slen_bitlength_estimate);

  /* slen = factorial_length(s),
     kmid = estimated a(n).
     The extra bitlength(s) + 64 bits of precision go into the
     fractional part which estimate_k() uses. */
  {
    double s_bitlength_and_margin = mpz_sizeinbase(s,2) + 64.0;
    set_precision (slen_bitlength_estimate + s_bitlength_and_margin);
    mpfr_t slen_frac;
    mpfr_init2 (slen_frac, s_bitlength_and_margin);
    factorial_length (slen, s, slen_frac);
    estimate_k (kmid, s, slen_frac);
    mpfr_clear (slen_frac);

    if (option_verbose) {
      gmp_printf("# length(s!) = %Zd\n", slen);
      gmp_printf("# k estimate = %Zd\n", kmid);
    }
  }

  /* From here only integer parts of lengths are needed.  Start with a
     margin and let factorial_length() increase precision if needed. */
  set_precision (slen_bitlength_estimate * 1.5 + 64);
  {
    int sign = delta_sign (n, s,slen, kmid);
    if (option_verbose)
      printf("# target is %s estimate\n",
             sign==0 ? ">=" : "<");

    if (sign == 0) {
      /* estimate is a valid k (delta 0), probe upwards for khi */
      mpz_set (k, kmid);
      probe_for_delta (khi,k, n,s,slen,kmid, 1);
    } else {
      /* estimate is a valid khi (delta > 0), probe downwards for k */
      mpz_set (khi, kmid);
      probe_for_delta (k,khi, n,s,slen,kmid, 0);
    }
  }

  /* Have delta(k) = 0 and delta(khi) > 0, so k < khi.
     Bisect, keeping those deltas, until khi = k+1. */
  if (option_verbose) {
    mpz_sub (kmid, khi,k);
    gmp_printf ("# bisection lo,hi are %Zd apart\n", kmid);
  }
  for (;;) {
    assert (mpz_cmp(k,khi) < 0);
    assert (delta_sign (n, s,slen, k) == 0);
    assert (delta_sign (n, s,slen, khi) == 1);

    mpz_add (kmid, k,khi);
    mpz_fdiv_q_2exp (kmid, kmid,1);   /* floor((k+khi)/2) */
    DEBUG(gmp_printf("at k  =%Zd\n   khi=%Zd\n  kmid=%Zd\n",
                     k,khi,kmid));

    /* midpoint equal to k, with k < khi, means khi = k+1: done */
    if (mpz_cmp (k, kmid) == 0)
      break;

    int sign = delta_sign (n, s,slen, kmid);
    DEBUG(gmp_printf("  delta sign %d\n", sign));
    if (sign == 0) mpz_set (k, kmid);
    else           mpz_set (khi, kmid);
  }

  mpz_clear (s);
  mpz_clear (slen);
  mpz_clear (kmid);
  mpz_clear (khi);
}

/* On the first call only, print build notes: the mpfr_prec_t size and
   "asserts not enabled" in verbose mode, and "asserts enabled"
   whenever assert() is active. */
void
show_compile_info (void)
{
  static int done = 0;
  if (done)
    return;
  done = 1;

  if (option_verbose) {
    printf("# mpfr_prec_t is %zu bytes\n", sizeof(mpfr_prec_t));
#if ! WANT_ASSERT
    printf("# assert()s not enabled\n");
#endif
  }
#ifndef NDEBUG
  printf("# assert()s enabled\n");
#endif
}

/* Calculate A084420(n), print it as the b-file line "n a(n)", and
   check it against the sample data. */
void
show (int n)
{
  mpz_t term;

  show_compile_info ();
  if (option_verbose) {
    printf ("#\n");
    printf ("# n=%d\n", n);
  }

  mpz_init (term);
  A084420_mpz (term, n);
  gmp_printf ("%d %Zd\n", n, term);
  check_data (n, term);
  mpz_clear (term);
}

int
main (int argc, char *argv[])
{
  setbuf(stdout,NULL);

  mpfr_init (log10_recip_lo);
  mpfr_init (log10_recip_hi);
  set_precision (64);

  int seen_n = 0;
  int i, n, nhi, endpos;
  for (i = 1; i < argc; i++) {
    const char *arg = argv[i];
    if (strcmp(arg,"-v")==0) {
      option_verbose++;
    } else if (sscanf(arg, "%d%n", &n, &endpos) == 1
               && endpos == strlen(arg)) {
      seen_n = 1;
      show(n);
    } else if (sscanf(arg, "%d-%d%n", &n, &nhi, &endpos) == 2
               && endpos == strlen(arg)) {
      seen_n = 1;
      for( ; n<=nhi; n++)
        show (n);
    } else {
      error("Unrecognised command line option: %s\n", arg);
    }
  }

  if (! seen_n)
    for(n=1; n<=10; n++)
      show (n);
  return(0);
}