math.hpp

View source code here on GitHub!

Includes

uintmax_t factorial(uint32_t n)

Warning

This function only works for n no greater than MAX_FACTORIAL_64 or MAX_FACTORIAL_128, depending on the size of uintmax_t. For larger n it returns -1, which is UINTMAX_MAX because the return type is unsigned.

uintmax_t n_choose_r(uint32_t n, uint32_t r)

Returns n choose r. If the result overflows uintmax_t, returns -1 instead; because the return type is unsigned, this is UINTMAX_MAX.

Warning

This function only works for numbers smaller than MAX_FACTORIAL_64 or MAX_FACTORIAL_128, depending on the size of uintmax_t.

 1#pragma once
 2
 3#include "macros.hpp"
 4#include <math.h>
 5#include <stdlib.h>
 6#include <stdint.h>
 7#include <inttypes.h>
 8
 9using namespace std;
10
/// @brief Computes n! using uintmax_t arithmetic.
/// @param n  The value whose factorial is wanted (0! == 1! == 1).
/// @return n!, or -1 converted to uintmax_t (i.e. UINTMAX_MAX) when n! does
///         not fit in a uintmax_t.
inline uintmax_t factorial(uint32_t n) {
    uintmax_t ret = 1;
    for (uint32_t i = 2; i <= n; ++i) {
        // Unsigned multiplication wraps silently, so test for overflow before
        // multiplying. Unlike a fixed MAX_FACTORIAL_* limit, this works for
        // any width of uintmax_t.
        if (ret > UINTMAX_MAX / i)
            return UINTMAX_MAX;  // the documented "-1" overflow sentinel
        ret *= i;
    }
    return ret;
}
21
/// @brief Computes the binomial coefficient "n choose r".
/// @param n  Number of items to choose from.
/// @param r  Number of items chosen.
/// @return C(n, r); 0 when r > n; or -1 converted to uintmax_t
///         (i.e. UINTMAX_MAX) when C(n, r) does not fit in a uintmax_t.
///
/// Builds the result with the recurrence C(m, i) = C(m - 1, i - 1) * m / i.
/// Before each multiply it cancels the gcd, so every intermediate value is an
/// exact binomial coefficient that is never larger than the final result.
/// The overflow check is therefore exact, no heap memory is needed, and the
/// loop runs min(r, n - r) times.
/// Declared inline because it is defined in a header.
inline uintmax_t n_choose_r(uint32_t n, uint32_t r) {
    if (r > n)
        return 0;  // there are no r-element subsets of an n-element set
    if (r > n - r)
        r = n - r;  // symmetry C(n, r) == C(n, n - r) keeps the loop short

    uintmax_t result = 1;  // invariant: result == C(n - r + i - 1, i - 1)
    for (uint32_t i = 1; i <= r; ++i) {
        uintmax_t num = n - r + i;  // <= n, so no uint32_t wrap-around
        uintmax_t den = i;

        // result * num is divisible by den. Cancel gcd(result, den) first;
        // what remains of den is coprime to result, so it must divide num.
        uintmax_t a = result;
        uintmax_t b = den;
        while (b != 0) {
            uintmax_t t = a % b;
            a = b;
            b = t;
        }
        result /= a;
        den /= a;
        num /= den;

        // result * num == C(n - r + i, i) <= C(n, r), so overflow here means
        // the final answer cannot be represented either.
        if (result > UINTMAX_MAX / num)
            return UINTMAX_MAX;  // the documented "-1" overflow sentinel
        result *= num;
    }
    return result;
}