math.hpp
View source code here on GitHub!
Includes
-
uintmax_t factorial(uint32_t n)
Warning
This function only works for numbers smaller than
MAX_FACTORIAL_64
orMAX_FACTORIAL_128
, depending on the size ofuintmax_t
.
-
uintmax_t n_choose_r(uint32_t n, uint32_t r)
Returns -1 if there is an overflow. Otherwise returns n choose r.
Warning
This function only works for numbers smaller than
MAX_FACTORIAL_64
orMAX_FACTORIAL_128
, depending on the size ofuintmax_t
.
1#pragma once
2
3#include "macros.hpp"
4#include <math.h>
5#include <stdlib.h>
6#include <stdint.h>
7#include <inttypes.h>
8
9using namespace std;
10
11uintmax_t factorial(uint32_t n);
12inline uintmax_t factorial(uint32_t n) {
13 // note that this function only works for numbers smaller than MAX_FACTORIAL_64
14 if ((sizeof(uintmax_t) == 8 && n > MAX_FACTORIAL_64) || (sizeof(uintmax_t) == 16 && n > MAX_FACTORIAL_128))
15 return -1;
16 uintmax_t ret = 1;
17 for (uint32_t i = 2; i <= n; ++i)
18 ret *= i;
19 return ret;
20}
21
22uintmax_t n_choose_r(uint32_t n, uint32_t r) {
23 // function returns -1 if it overflows
24 if ((sizeof(uintmax_t) == 8 && n <= MAX_FACTORIAL_64) || (sizeof(uintmax_t) == 16 && n <= MAX_FACTORIAL_128))
25 return factorial(n) / factorial(r) / factorial(n-r); // fast path if small enough
26 // slow path for larger numbers
27 int *factors;
28 uintmax_t answer, tmp;
29 uint32_t i, j;
30 factors = (int *) malloc(sizeof(int) * (n + 1));
31 // collect factors of final number
32 for (i = 2; i <= n; i++)
33 factors[i] = 1;
34 // negative factor values indicate need to divide
35 for (i = 2; i <= r; i++)
36 factors[i] -= 1;
37 for (i = 2; i <= n - r; i++)
38 factors[i] -= 1;
39 // this loop reduces to prime factors only
40 for (i = n; i > 1; i--)
41 for (j = 2; j < i; j++)
42 if (i % j == 0) {
43 factors[j] += factors[i];
44 factors[i / j] += factors[i];
45 factors[i] = 0;
46 break;
47 }
48 i = j = 2;
49 answer = 1;
50 while (i <= n) {
51 while (factors[i] > 0) {
52 tmp = answer;
53 answer *= i;
54 while (answer < tmp && j <= n) {
55 while (factors[j] < 0) {
56 tmp /= j;
57 factors[j]++;
58 }
59 j++;
60 answer = tmp * i;
61 }
62 if (answer < tmp)
63 return -1; // this indicates an overflow
64 factors[i]--;
65 }
66 i++;
67 }
68 while (j <= n) {
69 while (factors[j] < 0) {
70 answer /= j;
71 factors[j]++;
72 }
73 j++;
74 }
75 free(factors);
76 return answer;
77}