Skip to content

Commit 0e0dc06

Browse files
williamfisetclaude
andauthored
Refactor ChineseRemainderTheorem: reuse PrimeFactorization, add tests (#1314)
* Refactor ChineseRemainderTheorem: reuse PrimeFactorization, add tests Remove duplicated primeFactorization, pollardRho, isPrime, and gcf methods in favor of the shared PrimeFactorization class. Simplify prime-power grouping loop and clean up Javadoc. Add 17 JUnit 5 tests covering eliminateCoefficient, crt, reduce, egcd, and full reduce+crt integration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Keep CRT self-contained: inline prime factorization with Miller-Rabin Restore primeFactorization, pollardRho, isPrime, and supporting methods as private methods within ChineseRemainderTheorem rather than delegating to PrimeFactorization. Uses the improved Miller-Rabin primality test and overflow-safe mulMod instead of the original trial-division isPrime. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Simplify class Javadoc: remove @link references Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Expand class Javadoc with detailed CRT explanation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Remove UNTESTED note from Chinese remainder theorem in README Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3d3a885 commit 0e0dc06

File tree

4 files changed

+335
-95
lines changed

4 files changed

+335
-95
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch
239239

240240
# Mathematics
241241

242-
- [[UNTESTED] Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java)
242+
- [Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java)
243243
- [Prime number sieve (sieve of Eratosthenes)](src/main/java/com/williamfiset/algorithms/math/SieveOfEratosthenes.java) **- O(nlog(log(n)))**
244244
- [Prime number sieve (sieve of Eratosthenes, compressed)](src/main/java/com/williamfiset/algorithms/math/CompressedPrimeSieve.java) **- O(nlog(log(n)))**
245245
- [Totient function (phi function, relatively prime number count)](src/main/java/com/williamfiset/algorithms/math/EulerTotientFunction.java) **- O(√n)**
Lines changed: 156 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
/**
2-
* Use the chinese remainder theorem to solve a set of congruence equations.
2+
* Solve a set of congruence equations using the Chinese Remainder Theorem (CRT).
33
*
4-
* <p>The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod
5-
* m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the
6-
* coefficient. A value of null is returned if the coefficient cannot be eliminated.
4+
* Given a system of simultaneous congruences:
75
*
8-
* <p>The second method (reduce) is used to reduce a set of equations so that the moduli become
9-
* pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and
10-
* output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod
11-
* m_n−1). Note that the number of equations may change during this process. A value of null is
12-
* returned if the set of equations cannot be reduced to co-prime moduli.
6+
* x ≡ a_0 (mod m_0)
7+
* x ≡ a_1 (mod m_1)
8+
* ...
9+
* x ≡ a_{n-1} (mod m_{n-1})
1310
*
14-
* <p>The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
15-
* moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
16-
* m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
17-
* x≡a_new(mod m_new)x≡a_new(mod m_new).
11+
* where all moduli m_i are pairwise coprime (gcd(m_i, m_j) = 1 for i ≠ j), the CRT guarantees a
12+
* unique solution x modulo M = m_0 * m_1 * ... * m_{n-1}.
13+
*
14+
* The solution is constructed as x = sum of a_i * M_i * y_i (mod M), where M_i = M / m_i and y_i
15+
* is the modular inverse of M_i modulo m_i (found via the extended Euclidean algorithm). Each term
16+
* contributes a_i for the i-th congruence and vanishes (mod m_j) for all j ≠ i, so the sum
17+
* satisfies every equation simultaneously.
18+
*
19+
* When moduli are not pairwise coprime, the system must first be reduced. Each modulus is split
20+
* into prime-power factors (e.g. 12 = 4 * 3), converting one equation into several with
21+
* prime-power moduli. Redundant equations are removed and conflicting ones detected. After
22+
* reduction, the moduli are pairwise coprime and the standard CRT applies.
23+
*
24+
* The eliminateCoefficient method handles equations of the form cx ≡ a (mod m) by dividing through
25+
* by gcd(c, m) — which is only possible when gcd(c, m) divides a — and then multiplying by the
26+
* modular inverse of the reduced coefficient.
1827
*
1928
* @author Micah Stairs
2029
*/
@@ -24,12 +33,16 @@
2433

2534
public class ChineseRemainderTheorem {
2635

27-
// eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
36+
/**
37+
* Reduces cx ≡ a (mod m) to x ≡ a' (mod m').
38+
*
39+
* @return {a', m'} or null if unsolvable.
40+
*/
2841
public static long[] eliminateCoefficient(long c, long a, long m) {
29-
3042
long d = egcd(c, m)[0];
3143

32-
if (a % d != 0) return null;
44+
if (a % d != 0)
45+
return null;
3346

3447
c /= d;
3548
a /= d;
@@ -42,36 +55,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) {
4255
return new long[] {a, m};
4356
}
4457

45-
// reduce() takes a set of equations and reduces them to an equivalent
46-
// set with pairwise co-prime moduli (or null if not solvable).
58+
/**
59+
* Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli.
60+
*
61+
* @return {a[], m[]} with coprime moduli, or null if the system is inconsistent.
62+
*/
4763
public static long[][] reduce(long[] a, long[] m) {
64+
List<Long> aNew = new ArrayList<>();
65+
List<Long> mNew = new ArrayList<>();
4866

49-
List<Long> aNew = new ArrayList<Long>();
50-
List<Long> mNew = new ArrayList<Long>();
51-
52-
// Split up each equation into prime factors
67+
// Split each modulus into prime-power factors
5368
for (int i = 0; i < a.length; i++) {
5469
List<Long> factors = primeFactorization(m[i]);
5570
Collections.sort(factors);
56-
ListIterator<Long> iterator = factors.listIterator();
57-
while (iterator.hasNext()) {
58-
long val = iterator.next();
59-
long total = val;
60-
while (iterator.hasNext()) {
61-
long nextVal = iterator.next();
62-
if (nextVal == val) {
63-
total *= val;
64-
} else {
65-
iterator.previous();
66-
break;
67-
}
71+
72+
int j = 0;
73+
while (j < factors.size()) {
74+
long p = factors.get(j);
75+
long pk = p;
76+
j++;
77+
while (j < factors.size() && factors.get(j) == p) {
78+
pk *= p;
79+
j++;
6880
}
69-
aNew.add(a[i] % total);
70-
mNew.add(total);
81+
aNew.add(a[i] % pk);
82+
mNew.add(pk);
7183
}
7284
}
7385

74-
// Throw away repeated information and look for conflicts
86+
// Remove redundant equations and detect conflicts
7587
for (int i = 0; i < aNew.size(); i++) {
7688
for (int j = i + 1; j < aNew.size(); j++) {
7789
if (mNew.get(i) % mNew.get(j) == 0 || mNew.get(j) % mNew.get(i) == 0) {
@@ -81,111 +93,161 @@ public static long[][] reduce(long[] a, long[] m) {
8193
mNew.remove(j);
8294
j--;
8395
continue;
84-
} else return null;
96+
} else
97+
return null;
8598
} else {
8699
if ((aNew.get(j) % mNew.get(i)) == aNew.get(i)) {
87100
aNew.remove(i);
88101
mNew.remove(i);
89102
i--;
90103
break;
91-
} else return null;
104+
} else
105+
return null;
92106
}
93107
}
94108
}
95109
}
96110

97-
// Put result into an array
98111
long[][] res = new long[2][aNew.size()];
99112
for (int i = 0; i < aNew.size(); i++) {
100113
res[0][i] = aNew.get(i);
101114
res[1][i] = mNew.get(i);
102115
}
103-
104116
return res;
105117
}
106118

119+
/**
120+
* Solves x ≡ a[i] (mod m[i]) assuming all moduli are pairwise coprime.
121+
*
122+
* @return {x, M} where M is the product of all moduli.
123+
*/
107124
public static long[] crt(long[] a, long[] m) {
108-
109125
long M = 1;
110-
for (int i = 0; i < m.length; i++) M *= m[i];
111-
112-
long[] inv = new long[a.length];
113-
for (int i = 0; i < inv.length; i++) inv[i] = egcd(M / m[i], m[i])[1];
126+
for (long mi : m)
127+
M *= mi;
114128

115129
long x = 0;
116130
for (int i = 0; i < m.length; i++) {
117-
x += (M / m[i]) * a[i] * inv[i]; // Overflow could occur here
131+
long Mi = M / m[i];
132+
long inv = egcd(Mi, m[i])[1];
133+
x += Mi * a[i] * inv;
118134
x = ((x % M) + M) % M;
119135
}
120136

121137
return new long[] {x, M};
122138
}
123139

124-
private static ArrayList<Long> primeFactorization(long n) {
125-
ArrayList<Long> factors = new ArrayList<Long>();
126-
if (n <= 0) throw new IllegalArgumentException();
127-
else if (n == 1) return factors;
128-
PriorityQueue<Long> divisorQueue = new PriorityQueue<Long>();
129-
divisorQueue.add(n);
130-
while (!divisorQueue.isEmpty()) {
131-
long divisor = divisorQueue.remove();
132-
if (isPrime(divisor)) {
133-
factors.add(divisor);
134-
continue;
135-
}
136-
long next_divisor = pollardRho(divisor);
137-
if (next_divisor == divisor) {
138-
divisorQueue.add(divisor);
139-
} else {
140-
divisorQueue.add(next_divisor);
141-
divisorQueue.add(divisor / next_divisor);
142-
}
143-
}
140+
/** Extended Euclidean algorithm. Returns {gcd(a,b), x, y} where ax + by = gcd(a,b). */
141+
static long[] egcd(long a, long b) {
142+
if (b == 0)
143+
return new long[] {a, 1, 0};
144+
long[] ret = egcd(b, a % b);
145+
long tmp = ret[1] - ret[2] * (a / b);
146+
ret[1] = ret[2];
147+
ret[2] = tmp;
148+
return ret;
149+
}
150+
151+
private static List<Long> primeFactorization(long n) {
152+
if (n <= 0)
153+
throw new IllegalArgumentException();
154+
List<Long> factors = new ArrayList<>();
155+
factor(n, factors);
144156
return factors;
145157
}
146158

159+
private static void factor(long n, List<Long> factors) {
160+
if (n == 1)
161+
return;
162+
if (isPrime(n)) {
163+
factors.add(n);
164+
return;
165+
}
166+
long d = pollardRho(n);
167+
factor(d, factors);
168+
factor(n / d, factors);
169+
}
170+
147171
private static long pollardRho(long n) {
148-
if (n % 2 == 0) return 2;
149-
// Get a number in the range [2, 10^6]
172+
if (n % 2 == 0)
173+
return 2;
150174
long x = 2 + (long) (999999 * Math.random());
151175
long c = 2 + (long) (999999 * Math.random());
152176
long y = x;
153177
long d = 1;
154178
while (d == 1) {
155-
x = (x * x + c) % n;
156-
y = (y * y + c) % n;
157-
y = (y * y + c) % n;
158-
d = gcf(Math.abs(x - y), n);
159-
if (d == n) break;
179+
x = mulMod(x, x, n) + c;
180+
if (x >= n)
181+
x -= n;
182+
y = mulMod(y, y, n) + c;
183+
if (y >= n)
184+
y -= n;
185+
y = mulMod(y, y, n) + c;
186+
if (y >= n)
187+
y -= n;
188+
d = gcd(Math.abs(x - y), n);
189+
if (d == n)
190+
break;
160191
}
161192
return d;
162193
}
163194

164-
// Extended euclidean algorithm
165-
private static long[] egcd(long a, long b) {
166-
if (b == 0) return new long[] {a, 1, 0};
167-
else {
168-
long[] ret = egcd(b, a % b);
169-
long tmp = ret[1] - ret[2] * (a / b);
170-
ret[1] = ret[2];
171-
ret[2] = tmp;
172-
return ret;
195+
/**
196+
* Deterministic Miller-Rabin primality test, correct for all long values. Uses 12 witnesses
197+
* guaranteed correct for n < 3.317 × 10^24.
198+
*/
199+
private static boolean isPrime(long n) {
200+
if (n < 2)
201+
return false;
202+
if (n < 4)
203+
return true;
204+
if (n % 2 == 0 || n % 3 == 0)
205+
return false;
206+
207+
long d = n - 1;
208+
int r = Long.numberOfTrailingZeros(d);
209+
d >>= r;
210+
211+
for (long a : new long[] {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
212+
if (a >= n)
213+
continue;
214+
long x = powMod(a, d, n);
215+
if (x == 1 || x == n - 1)
216+
continue;
217+
boolean composite = true;
218+
for (int i = 0; i < r - 1; i++) {
219+
x = mulMod(x, x, n);
220+
if (x == n - 1) {
221+
composite = false;
222+
break;
223+
}
224+
}
225+
if (composite)
226+
return false;
173227
}
228+
return true;
174229
}
175230

176-
private static long gcf(long a, long b) {
177-
return b == 0 ? a : gcf(b, a % b);
231+
private static long powMod(long base, long exp, long mod) {
232+
long result = 1;
233+
base %= mod;
234+
while (exp > 0) {
235+
if ((exp & 1) == 1)
236+
result = mulMod(result, base, mod);
237+
exp >>= 1;
238+
base = mulMod(base, base, mod);
239+
}
240+
return result;
178241
}
179242

180-
private static boolean isPrime(long n) {
181-
if (n < 2) return false;
182-
if (n == 2 || n == 3) return true;
183-
if (n % 2 == 0 || n % 3 == 0) return false;
184-
185-
int limit = (int) Math.sqrt(n);
186-
187-
for (int i = 5; i <= limit; i += 6) if (n % i == 0 || n % (i + 2) == 0) return false;
243+
private static long mulMod(long a, long b, long mod) {
244+
return java.math.BigInteger.valueOf(a)
245+
.multiply(java.math.BigInteger.valueOf(b))
246+
.mod(java.math.BigInteger.valueOf(mod))
247+
.longValue();
248+
}
188249

189-
return true;
250+
private static long gcd(long a, long b) {
251+
return b == 0 ? a : gcd(b, a % b);
190252
}
191253
}

src/test/java/com/williamfiset/algorithms/math/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,14 @@ java_test(
4747
runtime_deps = JUNIT5_RUNTIME_DEPS,
4848
deps = TEST_DEPS,
4949
)
50+
51+
# bazel test //src/test/java/com/williamfiset/algorithms/math:ChineseRemainderTheoremTest
52+
java_test(
53+
name = "ChineseRemainderTheoremTest",
54+
srcs = ["ChineseRemainderTheoremTest.java"],
55+
main_class = "org.junit.platform.console.ConsoleLauncher",
56+
use_testrunner = False,
57+
args = ["--select-class=com.williamfiset.algorithms.math.ChineseRemainderTheoremTest"],
58+
runtime_deps = JUNIT5_RUNTIME_DEPS,
59+
deps = TEST_DEPS,
60+
)

0 commit comments

Comments
 (0)